package theorycrafter.fitting.utils


/**
 * A collection that, unlike a set, may hold the same item more than once, and that can report how many
 * times each item occurs.
 *
 * Note that this collection is not threadsafe.
 * Also note that it hasn't been extensively tested.
 */
interface MultiSet<E> : Collection<E> {


    /**
     * Returns how many times [element] occurs in the multi-set, or 0 if it does not occur at all.
     */
    operator fun get(element: E): Int


    /**
     * The distinct items of the multi-set, each paired with the number of times it occurs.
     */
    val valuesWithAmounts: Set<Map.Entry<E, Int>>


}


/**
 * A [MultiSet] that additionally supports the modification operations of [MutableCollection].
 */
interface MutableMultiSet<E> :
    MultiSet<E>,
    MutableCollection<E>


/**
 * Creates a fresh [MutableMultiSet] that contains no items.
 */
fun <T> mutableMultiSetOf(): MutableMultiSet<T> {
    return MultiSetImpl<T>()
}


/**
 * Builds a new [MutableMultiSet] and adds every item of this collection to it.
 */
fun <T> Collection<T>.toMutableMultiSet(): MutableMultiSet<T> =
    MultiSetImpl<T>().also { it.addAll(this) }


/**
 * An implementation of a mutable multi-set, backed by a map from each distinct item to the number of times
 * it occurs.
 *
 * This class is not threadsafe.
 */
class MultiSetImpl<E> : MutableMultiSet<E> {


    /**
     * The underlying counting map - the number of times each item appears in the multi-set.
     *
     * The value should never be 0; when the count reaches 0, the key should be removed.
     */
    private val itemCount: MutableMap<E, Int> = mutableMapOf()


    /**
     * The total number of items in the multi-set, counting every occurrence of repeated items.
     *
     * This should always be equal to the sum of the values in [itemCount]. The setter is private so that
     * only this class (and its iterator) can keep the two in sync.
     */
    override var size: Int = 0
        private set


    /**
     * Returns the values and their amounts.
     *
     * This is a live view of the underlying counting map, not a snapshot.
     */
    override val valuesWithAmounts: Set<Map.Entry<E, Int>> = itemCount.entries


    /**
     * Returns whether [element] occurs at least once in the multi-set.
     */
    override fun contains(element: E): Boolean = itemCount.containsKey(element)


    /**
     * Returns the number of times [element] occurs in the multi-set; 0 if it doesn't occur.
     */
    override fun get(element: E): Int = itemCount.getOrDefault(element, 0)


    /**
     * Returns whether every item of [elements] occurs at least once in the multi-set.
     *
     * Note that this doesn't take into account the number of times an item appears in [elements].
     * For example, if the multi-set contains "foo" once, then `containsAll(listOf("foo", "foo"))` is `true`.
     */
    override fun containsAll(elements: Collection<E>): Boolean = itemCount.keys.containsAll(elements)


    /**
     * Returns whether the multi-set contains no items.
     */
    override fun isEmpty(): Boolean = itemCount.isEmpty()


    /**
     * Returns an iterator that yields every occurrence of every item; an item that occurs `n` times is
     * returned `n` times in a row.
     */
    override fun iterator(): MutableIterator<E> = MultiSetMutableIterator()


    /**
     * Adds one occurrence of [element]. Always returns `true`, because a multi-set accepts duplicates.
     */
    override fun add(element: E): Boolean {
        itemCount.merge(element, 1, Int::plus)
        size += 1
        return true
    }


    /**
     * Adds one occurrence for each item of [elements] (so an item appearing twice in [elements] is added
     * twice).
     *
     * @return `true` if at least one item was added, i.e. [elements] was not empty.
     */
    override fun addAll(elements: Collection<E>): Boolean {
        val initialSize = size
        for (element in detached(elements)) {
            add(element)
        }
        return size > initialSize
    }


    /**
     * Removes all items from the multi-set.
     */
    override fun clear() {
        itemCount.clear()
        size = 0
    }


    /**
     * Removes a single occurrence of [element].
     *
     * @return `true` if an occurrence was removed; `false` if [element] was not in the multi-set.
     */
    override fun remove(element: E): Boolean {
        val oldValue = itemCount[element] ?: return false

        // Keep the invariant that a count of 0 is never stored.
        if (oldValue == 1)
            itemCount.remove(element)
        else
            itemCount[element] = oldValue - 1

        size -= 1

        return true
    }


    /**
     * Removes one occurrence from the multi-set for each occurrence in [elements].
     *
     * Note that this differs from removing *all* occurrences: if the multi-set contains "foo" three times,
     * `removeAll(listOf("foo"))` leaves it with two.
     *
     * @return `true` if at least one occurrence was removed.
     */
    override fun removeAll(elements: Collection<E>): Boolean {
        val initialSize = size
        for (element in detached(elements)) {
            remove(element)
        }
        return size < initialSize
    }


    /**
     * Reduces the number of occurrences of each item to at most the number of times it appears in
     * [elements]; items that don't appear in [elements] are removed entirely.
     *
     * @return `true` if at least one occurrence was removed.
     */
    override fun retainAll(elements: Collection<E>): Boolean {
        // Computed up-front, so passing this multi-set itself as `elements` is safe.
        val counts = elements.groupingBy { it }.eachCount()
        val iterator = itemCount.iterator()
        val initialSize = size
        while (iterator.hasNext()) {
            val entry = iterator.next()
            val updatedCount = counts.getOrDefault(entry.key, 0)
            if (updatedCount < entry.value) {
                size -= (entry.value - updatedCount)
                if (updatedCount == 0)
                    iterator.remove()
                else
                    entry.setValue(updatedCount)
            }
        }

        return size < initialSize
    }


    /**
     * Returns [elements] itself, unless it is this very multi-set, in which case a snapshot list is returned.
     *
     * This lets bulk operations iterate over their argument while modifying this multi-set, even when the
     * caller passes the multi-set to itself.
     */
    private fun detached(elements: Collection<E>): Collection<E> =
        if (elements === this) toList() else elements


    /**
     * A [MutableIterator] over the elements in this [MultiSet].
     */
    private inner class MultiSetMutableIterator : MutableIterator<E> {


        /**
         * The iterator over entries in [itemCount].
         */
        private val countIterator = itemCount.iterator()


        /**
         * The entry whose item we're currently returning; `null` before the first call to [next] and right
         * after the entry has been removed from the map by [remove].
         */
        private var currentEntry: MutableMap.MutableEntry<E, Int>? = null


        /**
         * The number of items we've already returned from [currentEntry] that are still in the multi-set.
         */
        private var currentEntryReturnedCount = 0


        /**
         * Whether we've already been asked to remove the last value we returned.
         */
        private var removedLastReturned = false


        override fun hasNext(): Boolean {
            val entry = currentEntry
            val moreInCurrentEntry = (entry != null) && (currentEntryReturnedCount < entry.value)
            return moreInCurrentEntry || countIterator.hasNext()
        }


        override fun next(): E {
            val current = currentEntry
            val entry = if ((current == null) || (currentEntryReturnedCount >= current.value)) {
                // The current entry is exhausted (or there is none); move on to the next one.
                // If there is none, `countIterator.next()` throws and our state is left untouched.
                countIterator.next().also {
                    currentEntry = it
                    currentEntryReturnedCount = 0
                }
            } else {
                current
            }

            currentEntryReturnedCount += 1
            removedLastReturned = false
            return entry.key
        }


        /**
         * Removes from the multi-set the occurrence last returned by [next].
         *
         * @throws NoSuchElementException if [next] has not been called yet, or if [remove] has already been
         * called since the last call to [next].
         */
        override fun remove() {
            val entry = currentEntry
            if ((entry == null) || removedLastReturned) {
                throw NoSuchElementException("No element to remove; next() must be called before each remove()")
            }

            if (entry.value == 1) {
                // Last occurrence: drop the key so that a count of 0 is never stored.
                // `entry` is the entry most recently returned by `countIterator`, so this removes it.
                countIterator.remove()
                currentEntry = null
                currentEntryReturnedCount = 0
            } else {
                // One fewer occurrence remains, and one fewer of the returned occurrences still exists;
                // adjusting both keeps the remaining occurrences from being skipped.
                entry.setValue(entry.value - 1)
                currentEntryReturnedCount -= 1
            }

            size -= 1
            removedLastReturned = true
        }


    }


}


