-- parstat.lua
-- Copyright 2025 Petr Krajník
--
-- This work may be distributed and/or modified under the
-- conditions of the LaTeX Project Public License, either version 1.3
-- of this license or (at your option) any later version.
-- The latest version of this license is in
--   https://www.latex-project.org/lppl.txt
-- and version 1.3c or later is part of all distributions of LaTeX
-- version 2008 or later.
--
-- This work has the LPPL maintenance status 'author-maintained'.
--
-- The Current Maintainer of this work is Petr Krajník.
--
-- This work consists of the files parstat.opm and parstat.lua.
--

-- Node type ids, resolved by name through LuaTeX's node library
local NODE_HLIST, NODE_VLIST = node.id("hlist"), node.id("vlist")
local NODE_GLYPH, NODE_GLUE = node.id("glyph"), node.id("glue")

-- Glyph subtype bit masks (tested with `&` against n.subtype)
local GLYPH_LIGATURE_FLAG, GLYPH_GHOST_FLAG = 2, 4
-- Glue subtypes counted as interword spaces by is_space
-- (LuaTeX's spaceskip / xspaceskip glue subtypes)
local GLUE_SPACE, GLUE_XSPACE = 13, 14

-- How many per-line statistics are printed on one console line
local STATS_PER_LINE = 5

-- Statistic structure factory
local function stat_new(name)
    return {
        name = name,
        count = 0,
        average = 0,
        M2 = 0, -- Sum of squared diffs for stddev
        min = math.huge,
        max = 0 }
end

--- Create the pair of statistics tracked per paragraph and in total.
-- @return glyph statistic, space statistic (in that order)
local function make_stats()
    local glyphs = stat_new("Glyphs")
    local spaces = stat_new("Spaces")
    return glyphs, spaces
end

-- Running totals over every analyzed paragraph; parstat.reset()
-- starts them afresh.
local par_count = 0
local glyph_stat_total, space_stat_total = make_stats()

-- Front/back line skip counts, loaded by init_line_skips() when a
-- paragraph is processed (nil until then).
local front_skip, back_skip

--- Load front_skip/back_skip from the TeX count registers named
-- "_parstat_fskip" and "_parstat_bskip".
-- Raises an error if either value is negative.
local function init_line_skips()
    local counts = tex.count
    front_skip = counts["_parstat_fskip"]
    assert(front_skip >= 0, "negative front line skip")
    back_skip = counts["_parstat_bskip"]
    assert(back_skip >= 0, "negative back line skip")
end

--- Tell whether at least one line survives the front and back skips.
-- @param line_count number of lines in the paragraph
-- @return boolean
local function enough_lines(line_count)
    local skipped = front_skip + back_skip
    return line_count > skipped
end

--- Tell whether a line falls into the skipped front or back region.
-- @param line 1-based line index within the paragraph
-- @param line_count number of lines in the paragraph
-- @return true when the line is excluded from the statistics
local function is_line_ignored(line, line_count)
    local first_kept = front_skip + 1
    local last_kept = line_count - back_skip
    return not (line >= first_kept and line <= last_kept)
end

--- Write the per-paragraph header (paragraph number, current input
-- file and line) on a fresh terminal/log line.
local function print_header()
    local msg = string.format("Parstat %d in '%s' at line %d",
        par_count, status.filename, status.linenumber)
    texio.write_nl(msg)
end

--- Tell whether a node is a glyph without the ghost subtype bit.
local function is_real_glyph(n)
    if n.id ~= NODE_GLYPH then
        return false
    end
    return (n.subtype & GLYPH_GHOST_FLAG) == 0
end

--- Tell whether a glyph node has the ligature subtype bit set.
local function is_ligature(n)
    local bits = n.subtype & GLYPH_LIGATURE_FLAG
    return bits ~= 0
end

--- Tell whether a node is interword glue (GLUE_SPACE or GLUE_XSPACE).
local function is_space(n)
    if n.id ~= NODE_GLUE then
        return false
    end
    local kind = n.subtype
    return kind == GLUE_SPACE or kind == GLUE_XSPACE
end

--- Tell whether a node is a box (hlist or vlist) with its own sublist.
local function is_nested_node(n)
    local id = n.id
    return id == NODE_VLIST or id == NODE_HLIST
end

--- Count glyphs and interword spaces in a node list.
-- Recurses into nested hlist/vlist boxes and into the components of
-- ligatures that carry them. Every space also adds one to the glyph
-- count, so the glyph total includes the spaces.
-- NOTE(review): discretionary (disc) nodes are not descended into, so
-- glyphs held only inside a disc's sublists are not counted. Confirm
-- this is intended.
-- @param head first node of the list to scan
-- @return glyph count, space count
local function count_glyphs(head)
    local glyphs, spaces = 0, 0

    for item in node.traverse(head) do
        local g, s = 0, 0
        if is_real_glyph(item) then
            if is_ligature(item) and item.components then
                -- A ligature counts as its components (see limitations)
                g, s = count_glyphs(item.components)
            else
                g = 1
            end
        elseif is_space(item) then
            g, s = 1, 1
        elseif is_nested_node(item) then
            g, s = count_glyphs(item.head)
        end
        glyphs, spaces = glyphs + g, spaces + s
    end

    return glyphs, spaces
end

--- Marker appended to the line number of an ignored line.
-- @param line_ignored truthy when the line is skipped
-- @return "X" for ignored lines, "" otherwise
local function get_ignore_mark(line_ignored)
    return line_ignored and "X" or ""
end

--- Print the statistics of one paragraph line to the terminal/log.
-- Entries are grouped STATS_PER_LINE to a console line. A new console
-- line starts before paragraph lines 1, 1 + STATS_PER_LINE,
-- 1 + 2 * STATS_PER_LINE, ...
-- @param line 1-based index of the line within the paragraph
-- @param line_ignored true when the line is excluded from the statistics
-- @param glyph_count glyphs on the line (spaces included)
-- @param space_count interword spaces on the line
local function print_line_stat(line, line_ignored, glyph_count, space_count)
    -- Test (line - 1) rather than `line % STATS_PER_LINE == 1`: the latter
    -- never fires when STATS_PER_LINE is 1.
    if (line - 1) % STATS_PER_LINE == 0 then
        texio.write_nl(" ") -- Limit line stats per line
    end
    texio.write(string.format(" %d%s:(%dg,%ds)", line,
        get_ignore_mark(line_ignored), glyph_count, space_count))
end

--- Fold one sample into a statistic created by stat_new.
-- Updates count, running mean and M2 (Welford's online algorithm),
-- and the min/max.
-- @param stat statistic table (modified in place)
-- @param value the new sample
local function stat_add(stat, value)
    local n = stat.count + 1
    stat.count = n

    -- Welford: M2 grows by (x - old_mean) * (x - new_mean)
    local mean = stat.average
    local before = value - mean
    mean = mean + before / n
    stat.average = mean
    stat.M2 = stat.M2 + before * (value - mean)

    if value < stat.min then stat.min = value end
    if value > stat.max then stat.max = value end
end

--- Sample standard deviation (divisor count - 1) from Welford's M2.
-- @return 0.0 when fewer than two samples were added
local function stat_stddev(stat)
    local n = stat.count
    if n <= 1 then
        return 0.0
    end
    local variance = stat.M2 / (n - 1)
    return math.sqrt(variance)
end

--- Minimum seen so far, reported as 0 when no sample was added yet
-- (stat.min still holds the math.huge sentinel).
local function stat_min(stat)
    local m = stat.min
    return (m == math.huge) and 0 or m
end

--- Write one summary line for a statistic: average, standard deviation,
-- minimum and maximum.
local function stat_print(stat)
    local line = string.format(
        "  %s: ave %.2f, stddev %.2f, min %d, max %d",
        stat.name, stat.average, stat_stddev(stat), stat_min(stat), stat.max)
    texio.write_nl(line)
end

--- Print per-line and per-paragraph statistics for one paragraph.
-- Lines inside the front/back skip regions are still printed (marked),
-- but they are left out of both the paragraph and the running totals.
-- NOTE(review): every hlist directly in `head` is treated as a line, so
-- a non-line hlist in that list would also be numbered as a line.
-- @param head node list whose hlist nodes are the paragraph lines
-- @param line_count number of hlist nodes in `head`
local function print_par_stat(head, line_count)
    local glyph_stat, space_stat = make_stats()
    local index = 0

    for hbox in node.traverse_id(NODE_HLIST, head) do
        index = index + 1
        local skipped = is_line_ignored(index, line_count)

        local glyphs, spaces = count_glyphs(hbox.head)
        print_line_stat(index, skipped, glyphs, spaces)

        if not skipped then
            stat_add(glyph_stat, glyphs)
            stat_add(glyph_stat_total, glyphs)
            stat_add(space_stat, spaces)
            stat_add(space_stat_total, spaces)
        end
    end
    assert(index == line_count)

    stat_print(glyph_stat)
    stat_print(space_stat)
    texio.write_nl("")
end

--------------------------------

-- Reuse an existing global `parstat` table if there is one, otherwise
-- start a new one; either way publish it as the global `parstat`.
local parstat = _ENV.parstat
if not parstat then
    parstat = {}
end
_ENV.parstat = parstat

--- Node-list handler: print statistics for a paragraph.
-- Acts only when the TeX count "_parstat_enabled" is positive and the
-- paragraph still has lines after the front/back skips. The list is
-- never modified.
-- NOTE(review): presumably registered as a post-line-break callback
-- (registration is not shown here); confirm.
-- @param head head of the node list containing the line hlists
-- @return head, unchanged
function parstat.run(head)
    if tex.count["_parstat_enabled"] <= 0 then
        return head
    end

    init_line_skips()
    local line_count = node.count(NODE_HLIST, head)
    if not enough_lines(line_count) then
        return head
    end

    par_count = par_count + 1
    print_header()
    print_par_stat(head, line_count)
    return head
end

--- Print the totals accumulated over all analyzed paragraphs.
-- The line count is glyph_stat_total.count, because one glyph sample
-- is added per non-ignored line.
function parstat.print_summary()
    local header = string.format(
        "Parstat summary from '%s' at line %d",
        status.filename, status.linenumber)
    local totals = string.format(
        "  Analyzed %d paragraphs, %d lines in total",
        par_count, glyph_stat_total.count)

    texio.write_nl(header)
    texio.write_nl(totals)
    stat_print(glyph_stat_total)
    stat_print(space_stat_total)
    texio.write_nl("")
end

--- Discard the accumulated totals and the paragraph counter.
function parstat.reset()
    par_count = 0
    glyph_stat_total, space_stat_total = make_stats()
end
