import type { LegacyRef } from "react"
import { useCallback, useEffect, useRef } from "react"
import { getTabIndexOfNode, sortByTabIndex } from "../utils/domUtils"

/**
 * CSS selector matching every element that can potentially receive keyboard
 * focus. Candidates are filtered further by tab index before use, so this
 * list is intentionally broad.
 */
const focusableElementsSelector = [
    "a[href]",
    "area[href]",
    "input:not([disabled]):not([type=hidden])",
    "select:not([disabled])",
    "textarea:not([disabled])",
    "button:not([disabled])",
    "iframe",
    "object",
    "embed",
    "*[tabindex]",
    "*[contenteditable]",
].join(", ")

/**
 * Collects the elements inside `container` that take part in the Tab order
 * (matched by `focusableElementsSelector` and with a tab index >= 0), sorted
 * with the `sortByTabIndex` comparator.
 *
 * The result is deliberately not cached: in dynamic applications the set of
 * focusable children can change between key presses.
 */
function collectFocusableElements(container: HTMLElement): Element[] {
    return Array.from(container.querySelectorAll(focusableElementsSelector))
        .filter((elem) => getTabIndexOfNode(elem) >= 0)
        .sort(sortByTabIndex)
}

/**
 * Moves focus one step through `sortedFocusableElems`, starting from the
 * element that currently holds focus and wrapping around at both ends.
 * A candidate whose `focus()` call does not make it `document.activeElement`
 * is skipped. Every candidate is tried at most once, so this always terminates.
 *
 * @param sortedFocusableElems candidates in the order Tab should visit them
 * @param backwards step towards lower indices (Shift+Tab) instead of higher
 * @returns `true` if a candidate received focus, `false` if the list is empty
 *   or no candidate accepted focus
 */
function focusNextElement(
    sortedFocusableElems: readonly Element[],
    backwards: boolean,
): boolean {
    const count = sortedFocusableElems.length
    if (count === 0) {
        // Nothing focusable inside the trap; indexing below would yield undefined.
        return false
    }

    const maxIndex = count - 1
    const active = document.activeElement
    // -1 when focus is currently outside the list: stepping forward then lands
    // on the first element, stepping backwards on the last one.
    let index = active === null ? -1 : sortedFocusableElems.indexOf(active)

    for (let attempt = 0; attempt < count; attempt++) {
        index = backwards ? index - 1 : index + 1
        if (index > maxIndex) {
            index = 0
        } else if (index < 0) {
            index = maxIndex
        }

        const candidate = sortedFocusableElems[index] as HTMLElement
        candidate.focus()
        if (document.activeElement === candidate) {
            return true
        }
        // The candidate refused focus (e.g. hidden); keep stepping from its
        // index rather than from the still-active element.
    }

    // Every candidate refused focus.
    return false
}

/**
 * Traps keyboard focus inside the element the returned ref is attached to.
 *
 * While the component using this hook is mounted, a `keydown` listener on
 * `window` intercepts every event whose key is `"Tab"`. It prevents the
 * default action and moves focus to the next focusable element inside the
 * container, or to the previous one when Shift is held, wrapping around at
 * both ends. If the ref is not attached to an element, key presses are left
 * untouched.
 *
 * @returns a one-element tuple holding the ref to attach to the container
 */
export function useFocusTrap(): [LegacyRef<HTMLElement> | undefined] {
    const trapRef = useRef<HTMLElement>(null)

    useEffect(() => {
        const trapper = (evt: KeyboardEvent): void => {
            const container = trapRef.current
            if (container === null || evt.key !== "Tab") {
                return
            }
            evt.preventDefault()
            focusNextElement(collectFocusableElements(container), evt.shiftKey)
        }

        window.addEventListener("keydown", trapper)

        return () => {
            window.removeEventListener("keydown", trapper)
        }
    }, [])

    return [trapRef]
}
