Python script to extract common elements from YAML entries into a derivable base element

This script was developed for optimizing kicad-footprint-generator YAML definitions.

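It scans a footprint YAML file, finds keys that appear in most entries (the required fraction is configurable via --threshold) with an identical value everywhere they appear, and moves them into a single shared base element. Each footprint that contains all of the base values is then rewritten to merge the base and keep only the parameters that actually differ from it, which keeps the file shorter and reduces duplication.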

For example, given a minimal input like this:

original.yaml
Vertical:
    pad_width: 1.0
    pad_height: 2.0
    solder_mask_margin: 0.05

Horizontal:
    pad_width: 1.0
    pad_height: 2.0
    solder_mask_margin: 0.05
    special_option: true

the script can extract the common values into a base and rewrite the file roughly as:

processed.yaml
base: &base
    pad_width: 1.0
    pad_height: 2.0
    solder_mask_margin: 0.05

Vertical:
    <<: *base

Horizontal:
    <<: *base
    special_option: true

This is essentially what it does for the much larger KiCad generator YAMLs.

yaml_extract_common_base.py
#!/usr/bin/env python3
# SPDX-FileCopyrightText: 2025 Uli Köhler <[email protected]>
# SPDX-License-Identifier: CC0-1.0
"""Extract and consolidate common values from KiCad generator YAML files."""

from __future__ import annotations

import argparse
import copy
import importlib
import io
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

try:
    _ruamel_yaml = importlib.import_module("ruamel.yaml")
    _ruamel_comments = importlib.import_module("ruamel.yaml.comments")
except ImportError as exc:  # pragma: no cover - dependency missing at runtime
    raise SystemExit(
        "The ruamel.yaml package is required. Install it via 'pip install ruamel.yaml'."
    ) from exc

# Short aliases for the ruamel.yaml classes loaded dynamically above.
YAML = _ruamel_yaml.YAML
CommentedMap = _ruamel_comments.CommentedMap
CommentedSeq = _ruamel_comments.CommentedSeq

# Text shared with the comment that rewrite() places above a generated base
# block; strip_generated_base() drops prefix lines containing it.
BASE_COMMENT_MARKER = "Base configuration extracted by YAMLCommon"


# Single round-trip ("rt") YAML instance used for loading blocks and for
# dumping output.
RT_YAML = YAML(typ="rt")
RT_YAML.preserve_quotes = True
RT_YAML.allow_duplicate_keys = True
# Very large line width, presumably so long values are not wrapped on output
# -- TODO confirm against the ruamel.yaml documentation.
RT_YAML.width = 4096
RT_YAML.indent(mapping=2, sequence=4, offset=2)


@dataclass
class Entry:
    """Represents a single footprint entry in the custom YAML file.

    Instances are created by ``parse_file`` from one block of consecutive
    non-blank lines plus the raw lines around it, and are written back out
    by ``rewrite``.
    """

    # Raw lines collected between the previous entry and this block's first
    # line (always empty for the first block, whose preceding lines become
    # the file prefix instead).
    leading_lines: List[str]
    # Blank lines directly after the block, followed by any comment lines.
    separator_lines: List[str]
    # Line after the separators that is_name_line() accepted, if any.
    name_line: Optional[str]
    # Remaining raw lines up to the next block start.
    after_name_lines: List[str]
    # The single top-level key of the block (e.g. "Vertical").
    root_key: str
    # Mapping stored under root_key.
    inner_map: CommentedMap
    # The whole loaded block, i.e. the one-key mapping {root_key: inner_map}.
    outer_map: CommentedMap
    # True when a "<<: *<anchor>" merge line was stripped from the block, or
    # when ensure_base_inheritance() marked the entry afterwards.
    inherits_base: bool = False
    # True when any line of the block contains "&<anchor>".
    defines_anchor: bool = False


@dataclass
class ParsedFile:
    """Holds the structural pieces of the parsed file.

    Produced by ``parse_file`` and again, with the generated base block
    removed, by ``strip_generated_base``.
    """

    # Raw lines that precede the first block.
    prefix_lines: List[str]
    # Parsed blocks in file order.
    entries: List[Entry]
    # Raw lines left over after the last entry.
    suffix_lines: List[str]
    # Inner mapping of a block that defined the anchor, if one was found by
    # strip_generated_base(); parse_file() leaves this at None.
    existing_base: Optional[CommentedMap] = None


def is_block_start(line: str) -> bool:
    """Return True if *line* marks the start of a footprint block.

    A block start is an unindented, non-comment line of the form
    ``key: ...`` whose key (the text before the first colon) is not empty.

    Comment lines are rejected explicitly: an unindented comment that
    happens to contain a colon (``# TODO: check pads``) is not a YAML key,
    and treating it as one makes a free-standing comment paragraph look like
    a block that cannot be loaded as a mapping.
    """

    stripped = line.strip()
    # Blank lines and lines indented with spaces belong to a block body.
    if not stripped or line.startswith(" "):
        return False
    if stripped.startswith("#"):
        return False
    key, colon, _rest = stripped.partition(":")
    return bool(colon) and bool(key.strip())


def is_name_line(line: str) -> bool:
    """Heuristically decide whether *line* is a footprint name line.

    A name line is indented by at least two spaces, is not blank, contains
    no colon, and does not start (after stripping) with ``#`` or ``-``.
    """

    content = line.strip()
    looks_like_yaml = ":" in content or content.startswith(("#", "-"))
    return bool(content) and not looks_like_yaml and line.startswith("  ")


def parse_file(text: str, anchor_name: str) -> ParsedFile:
    """Parse the non-standard YAML file into structured entries.

    The text is scanned line by line. Each block starts at a line accepted
    by ``is_block_start`` and runs up to the next blank line; it is loaded
    on its own with the round-trip YAML loader. The raw lines around a block
    (separators, an optional name line, and everything else up to the next
    block start) are kept verbatim on the resulting ``Entry``.

    Args:
        text: Full content of the input file.
        anchor_name: Anchor name used to detect the base block
            (``&<anchor_name>``) and merge lines (``<<: *<anchor_name>``).

    Returns:
        A ``ParsedFile`` with prefix lines, entries and suffix lines;
        ``existing_base`` is left at its default.

    Raises:
        ValueError: If a block is empty after stripping merge lines, does
            not load as a mapping with exactly one root key, or its root
            key does not map to a mapping.
    """

    lines = text.splitlines(keepends=True)
    idx = 0
    prefix_lines: List[str] = []
    suffix_lines: List[str] = []
    entries: List[Entry] = []
    first_block = True

    while True:
        # Collect everything that precedes the next block start.
        interim: List[str] = []
        while idx < len(lines) and not is_block_start(lines[idx]):
            interim.append(lines[idx])
            idx += 1

        # End of input: the collected lines are the prefix if no block was
        # ever seen, otherwise the suffix.
        if idx >= len(lines):
            if first_block:
                prefix_lines = interim
            else:
                suffix_lines = interim
            break

        if first_block:
            prefix_lines = interim
            leading_lines: List[str] = []
            first_block = False
        else:
            leading_lines = interim

        # A block extends from its start line to the next blank line (or EOF).
        block_lines: List[str] = []
        while idx < len(lines) and lines[idx].strip() != "":
            block_lines.append(lines[idx])
            idx += 1
        # Substring test: any line containing "&<anchor_name>" marks the block
        # as the base definition (this also matches longer anchor names that
        # merely start with anchor_name).
        defines_anchor = any(f"&{anchor_name}" in line for line in block_lines)
        # Merge lines are removed before loading so the block loads on its own
        # without the anchor being defined in the same document.
        block_lines, inherits_base = strip_merge_marker(block_lines, anchor_name)
        block_text = "".join(block_lines)
        if not block_text.strip():
            raise ValueError("Encountered empty block; the file format might be unsupported.")

        # Separator: the blank lines after the block, then any comment lines.
        separator_lines: List[str] = []
        while idx < len(lines) and lines[idx].strip() == "":
            separator_lines.append(lines[idx])
            idx += 1
        while idx < len(lines) and lines[idx].lstrip().startswith("#"):
            separator_lines.append(lines[idx])
            idx += 1

        name_line: Optional[str] = None
        if idx < len(lines) and is_name_line(lines[idx]):
            name_line = lines[idx]
            idx += 1

        # Everything else up to the next block start is kept verbatim.
        # NOTE: this loop only stops at a block start or at the end of input,
        # so the `interim` scan at the top of the next iteration finds nothing
        # left to collect.
        after_name_lines: List[str] = []
        while idx < len(lines) and not is_block_start(lines[idx]):
            after_name_lines.append(lines[idx])
            idx += 1

        data = RT_YAML.load(block_text)
        if not isinstance(data, CommentedMap) or len(data) != 1:
            raise ValueError("Each block must be a mapping with a single root key (e.g. Vertical).")

        root_key = next(iter(data))
        inner_map = data[root_key]
        if not isinstance(inner_map, CommentedMap):
            raise ValueError("Top-level entries must map to a dictionary of parameters.")

        entries.append(
            Entry(
                leading_lines=leading_lines,
                separator_lines=separator_lines,
                name_line=name_line,
                after_name_lines=after_name_lines,
                root_key=root_key,
                inner_map=inner_map,
                outer_map=data,
                inherits_base=inherits_base,
                defines_anchor=defines_anchor,
            )
        )

    return ParsedFile(prefix_lines=prefix_lines, entries=entries, suffix_lines=suffix_lines)


def strip_merge_marker(block_lines: List[str], anchor_name: str) -> tuple[List[str], bool]:
    """Drop ``<<: *<anchor_name>`` merge lines from a YAML block.

    Only lines whose stripped text is exactly the merge reference are
    removed. Returns the remaining lines together with a flag telling
    whether at least one line was removed.
    """

    marker = f"<<: *{anchor_name}"
    kept = [raw for raw in block_lines if raw.strip() != marker]
    return kept, len(kept) != len(block_lines)


def strip_generated_base(parsed: ParsedFile, anchor_name: str) -> ParsedFile:
    """Return a copy of *parsed* without a previously generated base block.

    Entries flagged as defining the anchor are removed and the inner mapping
    of the last such entry becomes ``existing_base`` (falling back to the
    value already stored on *parsed*). Prefix lines containing the generated
    base comment marker are dropped as well. *anchor_name* is accepted for
    interface symmetry but not consulted here.
    """

    recovered_base: Optional[CommentedMap] = parsed.existing_base
    kept: List[Entry] = []
    for candidate in parsed.entries:
        if not candidate.defines_anchor:
            kept.append(candidate)
        else:
            recovered_base = candidate.inner_map

    prefix_without_marker = [
        raw for raw in parsed.prefix_lines if BASE_COMMENT_MARKER not in raw
    ]
    return ParsedFile(
        prefix_lines=prefix_without_marker,
        entries=kept,
        suffix_lines=parsed.suffix_lines,
        existing_base=recovered_base,
    )


def ensure_base_inheritance(parsed: ParsedFile) -> None:
    """Flag entries that do not fully match the existing base as inheriting it.

    Does nothing when *parsed* carries no existing base. Entries already
    marked as inheriting are left alone; every other entry for which
    ``entry_supports_base`` is false gets ``inherits_base`` set in place.
    """

    if not parsed.existing_base:
        return
    for entry in parsed.entries:
        if not entry.inherits_base and not entry_supports_base(
            entry.inner_map, parsed.existing_base
        ):
            entry.inherits_base = True


def normalize(value):
    """Return *value* with ruamel containers replaced by plain dicts/lists.

    ``CommentedMap`` becomes ``dict`` and ``CommentedSeq`` becomes ``list``,
    recursively; any other value is returned unchanged.
    """

    if isinstance(value, CommentedMap):
        plain_map = {}
        for key, child in value.items():
            plain_map[key] = normalize(child)
        return plain_map
    if isinstance(value, CommentedSeq):
        return list(map(normalize, value))
    return value


def merged_with_base(base_map: CommentedMap, overrides: CommentedMap) -> CommentedMap:
    """Return a deep copy of *base_map* with *overrides* layered on top.

    Where both sides hold a ``CommentedMap`` under the same key the two are
    merged recursively; otherwise the override value (deep-copied) replaces
    or adds the key. Neither argument is modified.
    """

    merged = copy.deepcopy(base_map)
    for key, override in overrides.items():
        current = merged[key] if key in merged else None
        both_are_maps = isinstance(current, CommentedMap) and isinstance(override, CommentedMap)
        if both_are_maps:
            merged[key] = merged_with_base(current, override)
        else:
            merged[key] = copy.deepcopy(override)
    return merged


def values_equal(lhs, rhs) -> bool:
    """Compare two values deeply after converting both with ``normalize``."""

    left_plain = normalize(lhs)
    right_plain = normalize(rhs)
    return left_plain == right_plain


def _round_number(value, digits: Optional[int]):
    """Round numeric *value* to *digits* decimal places if requested.

    Args:
        value: Any scalar.
        digits: Number of decimal places, or ``None`` to disable rounding.

    Returns:
        The rounded number for ints and floats; *value* itself when
        *digits* is ``None``, when *value* is not numeric, or when it is a
        boolean.
    """

    if digits is None:
        return value
    # bool is a subclass of int and round(True, n) yields the integer 1, so
    # booleans must be passed through or YAML flags would be rewritten as 1/0.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return value
    try:
        return round(value, digits)
    except TypeError:
        # Best effort: leave the value untouched if it cannot be rounded.
        return value


def _round_node(node, digits: Optional[int]):
    """Round every numeric scalar inside *node*, modifying containers in place.

    ``CommentedMap`` and ``CommentedSeq`` containers are walked recursively
    and returned after their items have been replaced; any other value is
    passed to ``_round_number``. With *digits* set to ``None`` the node is
    returned untouched.
    """

    if digits is None:
        return node

    if isinstance(node, CommentedMap):
        # Iterate over a snapshot because values are reassigned while walking.
        snapshot = list(node.items())
        for key, child in snapshot:
            node[key] = _round_node(child, digits)
        return node

    if isinstance(node, CommentedSeq):
        elements = list(node)
        for position, element in enumerate(elements):
            node[position] = _round_node(element, digits)
        return node

    return _round_number(node, digits)



def collect_common(
    entries: List[Entry],
    threshold: float,
    fallback_base: Optional[CommentedMap] = None,
) -> CommentedMap:
    """Collect the key/value pairs that qualify for the shared base.

    A top-level key qualifies when it occurs in at least *threshold* (a
    fraction) of all entries and every occurrence carries an equal value.
    Entries marked as inheriting the base are evaluated as
    *fallback_base* merged with their own values, when a fallback is given.
    Returned values are deep copies; an empty mapping is returned for an
    empty entry list.
    """

    common = CommentedMap()
    entry_count = len(entries)
    if entry_count == 0:
        return common

    # key -> every value observed for that key, in entry order
    seen: dict[str, List[object]] = {}
    for entry in entries:
        if entry.inherits_base and fallback_base is not None:
            effective = merged_with_base(fallback_base, entry.inner_map)
        else:
            effective = entry.inner_map
        for key, value in effective.items():
            seen.setdefault(key, []).append(value)

    for key, observed in seen.items():
        if len(observed) / entry_count < threshold:
            continue
        reference, *others = observed
        if all(values_equal(reference, candidate) for candidate in others):
            common[key] = copy.deepcopy(reference)

    return common


def entry_supports_base(entry_map: CommentedMap, base_map: CommentedMap) -> bool:
    """Tell whether *entry_map* carries every key of *base_map* with a matching value.

    Nested mappings are checked recursively (the entry may hold extra
    keys); all other values, sequences included, must compare equal via
    ``values_equal``.
    """

    for key, expected in base_map.items():
        if key not in entry_map:
            return False
        actual = entry_map[key]
        if isinstance(expected, CommentedMap) and isinstance(actual, CommentedMap):
            matches = entry_supports_base(actual, expected)
        else:
            matches = values_equal(actual, expected)
        if not matches:
            return False
    return True


def build_overrides(entry_map: CommentedMap, base_map: CommentedMap) -> CommentedMap:
    """Compute what *entry_map* adds to or changes relative to *base_map*.

    Keys missing from the base are copied over. Where both sides hold a
    mapping, only the nested differences are kept (and the key is omitted
    when there are none). Any other value, sequences included, is kept when
    it does not compare equal to the base value. Kept values are deep copies.
    """

    diff = CommentedMap()
    for key, current in entry_map.items():
        if key in base_map:
            reference = base_map[key]
            if isinstance(current, CommentedMap) and isinstance(reference, CommentedMap):
                nested = build_overrides(current, reference)
                if nested:
                    diff[key] = nested
            elif not values_equal(current, reference):
                diff[key] = copy.deepcopy(current)
        else:
            diff[key] = copy.deepcopy(current)

    return diff


def dump_map(node: CommentedMap) -> str:
    """Serialise *node* with the shared round-trip YAML instance and return the text."""

    with io.StringIO() as sink:
        RT_YAML.dump(node, sink)
        return sink.getvalue()


def _supports_color(stream) -> bool:
    """Guess whether ANSI colour codes can be written to *stream*.

    True only when the stream reports being a TTY and the ``TERM``
    environment variable is set to something other than an empty string or
    ``dumb``. Streams without an ``isatty`` method count as non-TTY.
    """

    is_tty = getattr(stream, "isatty", lambda: False)
    if not is_tty():
        return False
    return os.environ.get("TERM") not in {None, "", "dumb"}


def _colorize(text: str, color_code: str = "31") -> str:
    """Return *text* wrapped in an ANSI colour sequence when stderr supports it.

    *color_code* is the SGR parameter placed in the escape sequence; when
    stderr does not support colour the text is returned unchanged.
    """

    if not _supports_color(sys.stderr):
        return text
    return f"\033[{color_code}m{text}\033[0m"


def inject_merge_line(rendered: str, anchor_name: str, indent: str = "  ") -> str:
    """Place a ``<<: *<anchor_name>`` merge line directly below the first line.

    The first line of *rendered* (expected to be the root key) is terminated
    with a newline if needed. For empty input just the merge line is
    returned.
    """

    merge_line = f"{indent}<<: *{anchor_name}\n"
    pieces = rendered.splitlines(keepends=True)
    if not pieces:
        return merge_line

    header, body = pieces[0], pieces[1:]
    if not header.endswith("\n"):
        header = header + "\n"
    return "".join([header, merge_line, *body])


def render_with_merge(root_key: str, overrides: CommentedMap, anchor_name: str) -> str:
    """Render ``root_key`` with a merge reference followed by its overrides.

    Without overrides the entry consists of the root key and the merge line
    only; otherwise the overrides are dumped under the root key and the
    merge line is inserted right after the key line.
    """

    if not overrides:
        return f"{root_key}:\n  <<: *{anchor_name}\n"

    wrapper = CommentedMap()
    wrapper[root_key] = overrides
    rendered = dump_map(wrapper)
    return inject_merge_line(rendered, anchor_name)


def rewrite(
    parsed: ParsedFile,
    base_map: Optional[CommentedMap],
    threshold: float,
    anchor_name: str,
    round_digits: Optional[int],
    target_path: Path,
) -> None:
    """Rewrite the YAML file with a base block followed by per-entry overrides.

    Args:
        parsed: Parsed file structure; its raw lines and entries are written
            back in their original order.
        base_map: Shared values to emit as the anchored base block. When it
            is ``None`` or empty no base block is written and every entry is
            dumped in full.
        threshold: Only echoed in the comment placed above the base block.
        anchor_name: Anchor set on the base mapping and referenced by each
            entry's ``<<`` merge line.
        round_digits: Number of decimal places to round numeric values to,
            or ``None`` to leave values untouched.
        target_path: File that receives the result (written as UTF-8).
    """

    fragments: List[str] = []
    fragments.extend(parsed.prefix_lines)

    base_inner = None
    if base_map and len(base_map):
        base_inner = copy.deepcopy(base_map)
        # Apply rounding to the shared base before anchoring, if requested.
        if round_digits is not None:
            _round_node(base_inner, round_digits)
        base_inner.yaml_set_anchor(anchor_name, always_dump=True)
        # The base element must always be named "base" in the
        # generated YAML, regardless of the original root key.
        base_container = CommentedMap()
        base_container["base"] = base_inner
        base_container.yaml_set_start_comment(
            f"Base configuration extracted by YAMLCommon (threshold={threshold:.2f})",
            indent=0,
        )
        fragments.append(dump_map(base_container))
        fragments.append("\n")

    for entry in parsed.entries:
        fragments.extend(entry.leading_lines)

        # Round a private copy of the entry first. The base has already been
        # rounded, so checking the unrounded entry against it would reject
        # entries that only differ from the base by the rounding itself and
        # dump them in full instead of as a merge reference.
        inner_copy = copy.deepcopy(entry.inner_map)
        if round_digits is not None:
            _round_node(inner_copy, round_digits)

        # An entry is rendered through a merge reference when it contains all
        # base values, or when it was marked as inheriting the base before
        # (e.g. because it had a merge line or was missing base keys).
        if base_inner and (entry_supports_base(inner_copy, base_inner) or entry.inherits_base):
            overrides = build_overrides(inner_copy, base_inner)
            # Even without overrides a merge-only entry is emitted so the
            # entry remains present in the file.
            fragments.append(render_with_merge(entry.root_key, overrides, anchor_name))
        else:
            # No base, or the entry does not match it: dump the whole entry.
            entry_map_for_write = copy.deepcopy(entry.outer_map)
            if round_digits is not None:
                _round_node(entry_map_for_write, round_digits)
            fragments.append(dump_map(entry_map_for_write))

        fragments.extend(entry.separator_lines)
        if entry.name_line:
            fragments.append(entry.name_line if entry.name_line.endswith("\n") else entry.name_line + "\n")
        fragments.extend(entry.after_name_lines)

    fragments.extend(parsed.suffix_lines)

    target_path.write_text("".join(fragments), encoding="utf-8")


def print_base(base_map: Optional[CommentedMap], threshold: float) -> None:
    """Print the extracted base element, or a notice that there is none.

    A non-empty *base_map* is shown on stdout as YAML under a ``base`` key,
    preceded by a header line mentioning *threshold*.
    """

    if base_map:
        wrapper = CommentedMap()
        wrapper["base"] = base_map
        print(f"Extracted base element (threshold={threshold:.2f}):")
        sys.stdout.write(dump_map(wrapper))
    else:
        print(f"No base attributes met the threshold ({threshold:.2f}).")


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the script.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``; leave at
            ``None`` for normal command-line use.

    Returns:
        Namespace with ``yaml_file``, ``threshold``, ``round``, ``output``,
        ``overwrite`` and ``anchor_name``.
    """

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("yaml_file", type=Path, help="Path to the footprint YAML file")
    parser.add_argument(
        "-t",
        "--threshold",
        type=float,
        default=0.9,
        help="Fraction of entries a key must appear in (default: 0.9)",
    )
    parser.add_argument(
        "-r",
        "--round",
        type=int,
        default=None,
        help=(
            "Round all numeric values to the given number of decimal "
            "places before writing output."
        ),
    )
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        help=(
            "Output path for the rewritten YAML file (default: the input file "
            "itself, which additionally requires --overwrite)."
        ),
    )
    # Writing happens in place whenever the target resolves to the input
    # file, which includes the case where --output is omitted entirely.
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help=(
            "Allow overwriting the input YAML file; required when --output is "
            "omitted or points at the input file."
        ),
    )
    parser.add_argument(
        "--anchor-name",
        default="base",
        help="Anchor name used for the shared base mapping (default: base)",
    )
    return parser.parse_args(argv)


def main() -> None:
    """Run the command-line workflow: parse, extract the base, report, rewrite.

    The extracted base is always printed. The file is only written when the
    target differs from the input file or ``--overwrite`` was given;
    otherwise a warning is printed to stderr and nothing is written.
    """

    options = parse_args()
    if not 0 < options.threshold <= 1:
        raise SystemExit("Threshold must be within (0, 1].")

    source_path = options.yaml_file
    parsed = strip_generated_base(
        parse_file(source_path.read_text(encoding="utf-8"), options.anchor_name),
        options.anchor_name,
    )
    ensure_base_inheritance(parsed)
    if not parsed.entries:
        raise SystemExit("No entries were found in the provided file.")

    common_map = collect_common(parsed.entries, options.threshold, parsed.existing_base)

    if not parsed.existing_base:
        shared_base = copy.deepcopy(common_map) if common_map else None
    else:
        # Keep the existing base and only add what the new scan found on top.
        shared_base = copy.deepcopy(parsed.existing_base)
        extras = build_overrides(common_map, shared_base) if common_map else None
        if extras:
            shared_base = merged_with_base(shared_base, extras)

    print_base(shared_base, options.threshold)

    destination = options.output or source_path
    writes_in_place = destination.resolve() == source_path.resolve()
    if writes_in_place and not options.overwrite:
        warning = (
            "--overwrite was not provided; the input file was left unchanged and no output was written."
        )
        print(_colorize(warning), file=sys.stderr)
        return

    rewrite(
        parsed,
        shared_base or None,
        options.threshold,
        options.anchor_name,
        options.round,
        destination,
    )


if __name__ == "__main__":
    main()

Check out similar posts by category: Python