Python script to extract common elements from YAML entries into a derivable base element
This script was developed for optimizing kicad-footprint-generator YAML definitions.
It scans a footprint YAML file, finds values that are identical across most entries, and moves them into a single shared base element. Each individual footprint then only keeps the parameters that actually differ from the base, which keeps the file shorter and reduces duplication.
For example, given a minimal input like this:
original.yaml
Vertical:
  pad_width: 1.0
  pad_height: 2.0
  solder_mask_margin: 0.05

Horizontal:
  pad_width: 1.0
  pad_height: 2.0
  solder_mask_margin: 0.05
  special_option: true
the script can extract the common values into a base and rewrite the file roughly as:
processed.yaml
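base: &base
  pad_width: 1.0
  pad_height: 2.0
  solder_mask_margin: 0.05

Vertical:
  <<: *base

Horizontal:
  <<: *base
  special_option: true
This is essentially what it does for the much larger KiCad generator YAMLs. Note that the script expects consecutive top-level entries to be separated by a blank line, because it treats a blank line as the end of an entry.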
yaml_extract_common_base.py
#!/usr/bin/env python3
# SPDX-FileCopyrightText: 2025 Uli Köhler <[email protected]>
# SPDX-License-Identifier: CC0-1.0
"""Extract and consolidate common values from KiCad generator YAML files."""
from __future__ import annotations

import argparse
import copy
import importlib
import io
import os
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
try:
    # A missing dependency is reported as a SystemExit with an install hint
    # instead of a bare ImportError traceback.
    _ruamel_yaml, _ruamel_comments = (
        importlib.import_module(module_name)
        for module_name in ("ruamel.yaml", "ruamel.yaml.comments")
    )
except ImportError as exc:  # pragma: no cover - dependency missing at runtime
    raise SystemExit(
        "The ruamel.yaml package is required. Install it via 'pip install ruamel.yaml'."
    ) from exc

YAML = _ruamel_yaml.YAML
CommentedMap = _ruamel_comments.CommentedMap
CommentedSeq = _ruamel_comments.CommentedSeq

# Text contained in the comment written above a generated base block; used to
# recognise and drop that comment when the file is processed again.
BASE_COMMENT_MARKER = "Base configuration extracted by YAMLCommon"


def _make_round_trip_yaml():
    """Create the round-trip YAML instance shared by every load/dump call."""
    yaml_rt = YAML(typ="rt")
    yaml_rt.preserve_quotes = True
    # Duplicate keys are tolerated instead of being rejected on load.
    yaml_rt.allow_duplicate_keys = True
    # Large preferred line width so long values are not re-wrapped on dump.
    yaml_rt.width = 4096
    yaml_rt.indent(mapping=2, sequence=4, offset=2)
    return yaml_rt


RT_YAML = _make_round_trip_yaml()
@dataclass
class Entry:
    """One top-level footprint block together with the raw text around it.

    The parsed mapping is kept alongside the untouched surrounding lines so
    the file can be written back with its original layout.
    """

    # Raw lines written out before the block.
    leading_lines: list[str]
    # Blank lines and comment lines that directly follow the block.
    separator_lines: list[str]
    # Optional indented, colon-free name line after the separator (else None).
    name_line: Optional[str]
    # Remaining raw lines up to the next block start.
    after_name_lines: list[str]
    # The block's single top-level key, e.g. "Vertical".
    root_key: str
    # The mapping stored under root_key.
    inner_map: CommentedMap
    # The whole loaded block, i.e. {root_key: inner_map}.
    outer_map: CommentedMap
    # True when a "<<: *<anchor>" merge line was removed from the block.
    inherits_base: bool = False
    # True when the block text declares the shared anchor ("&<anchor>").
    defines_anchor: bool = False
@dataclass
class ParsedFile:
    """Result of parsing a footprint file: the entries plus the text around them."""

    # Raw lines that precede the first block.
    prefix_lines: list[str]
    # Parsed top-level blocks in file order.
    entries: list[Entry]
    # Raw lines that follow the last block.
    suffix_lines: list[str]
    # Mapping of a previously generated base block, when one was found.
    existing_base: Optional[CommentedMap] = None
def is_block_start(line: str) -> bool:
    """Return True if *line* marks the start of a footprint block.

    A block starts with an unindented ``key:`` line.  Blank lines, indented
    lines, lines without a colon and comment lines are never block starts.

    Comments are rejected explicitly: a column-0 comment such as
    ``# SPDX-License-Identifier: CC0-1.0`` contains a colon and would
    otherwise be mistaken for a root key, after which loading that "block"
    yields no mapping and parsing fails.
    """
    stripped = line.strip()
    if not stripped or line.startswith(" "):
        return False
    if stripped.startswith("#"):
        return False
    key, separator, _ = stripped.partition(":")
    return bool(separator) and bool(key.strip())
def is_name_line(line: str) -> bool:
    """Heuristically decide whether *line* is a free-standing footprint name.

    A name line is indented, non-blank, contains no colon, and is neither a
    comment nor a sequence item.
    """
    if not line.startswith(" "):
        return False
    text = line.strip()
    if not text or ":" in text:
        return False
    return not text.startswith(("#", "-"))
def _block_defines_anchor(block_lines: List[str], anchor_name: str) -> bool:
    """Return True if a line of the block declares exactly ``&<anchor_name>``."""
    # A plain substring test would also match longer anchors such as
    # "&base_smd" when anchor_name is "base"; require the anchor token to end
    # at whitespace, a YAML flow indicator, or the end of the line.
    # NOTE(review): like the original check, this also matches the token
    # inside comments or quoted strings.
    pattern = re.compile("&" + re.escape(anchor_name) + r"(?![^\s,\[\]{}])")
    return any(pattern.search(line) for line in block_lines)


def parse_file(text: str, anchor_name: str) -> ParsedFile:
    """Split the custom footprint YAML *text* into structured entries.

    The file is a sequence of top-level blocks.  Each block is a single-key
    YAML mapping (e.g. ``Vertical:``) and is terminated by a blank line.
    After a block the parser collects, in order:

    - blank lines and then comment lines (``separator_lines``);
    - an optional indented, colon-free name line (``name_line``);
    - every remaining line up to the next block start (``after_name_lines``).

    Lines before the first block become ``prefix_lines``.  If there is no
    block at all, every line ends up there.  Because ``after_name_lines``
    runs up to the next block start or end of input, text after the last
    block is attached to that entry, and ``suffix_lines`` stays empty
    whenever entries exist.

    Merge references ``<<: *<anchor_name>`` are removed from each block
    before loading and recorded in ``Entry.inherits_base``.  Blocks that
    declare the anchor ``&<anchor_name>`` are flagged via
    ``Entry.defines_anchor``.

    Raises:
        ValueError: if a block is empty after removing merge references, is
            not a single-key mapping, or its root value is not a mapping.
    """
    lines = text.splitlines(keepends=True)
    idx = 0
    prefix_lines: List[str] = []
    suffix_lines: List[str] = []
    entries: List[Entry] = []
    first_block = True
    while True:
        # Free text up to the next top-level key.
        interim: List[str] = []
        while idx < len(lines) and not is_block_start(lines[idx]):
            interim.append(lines[idx])
            idx += 1
        if idx >= len(lines):
            if first_block:
                prefix_lines = interim
            else:
                suffix_lines = interim
            break
        if first_block:
            prefix_lines = interim
            leading_lines: List[str] = []
            first_block = False
        else:
            leading_lines = interim

        # The block itself extends up to the first blank line.
        block_lines: List[str] = []
        while idx < len(lines) and lines[idx].strip() != "":
            block_lines.append(lines[idx])
            idx += 1
        defines_anchor = _block_defines_anchor(block_lines, anchor_name)
        block_lines, inherits_base = strip_merge_marker(block_lines, anchor_name)
        block_text = "".join(block_lines)
        if not block_text.strip():
            raise ValueError("Encountered empty block; the file format might be unsupported.")

        # Separator: blank lines first, then comment lines.
        separator_lines: List[str] = []
        while idx < len(lines) and lines[idx].strip() == "":
            separator_lines.append(lines[idx])
            idx += 1
        while idx < len(lines) and lines[idx].lstrip().startswith("#"):
            separator_lines.append(lines[idx])
            idx += 1

        name_line: Optional[str] = None
        if idx < len(lines) and is_name_line(lines[idx]):
            name_line = lines[idx]
            idx += 1

        after_name_lines: List[str] = []
        while idx < len(lines) and not is_block_start(lines[idx]):
            after_name_lines.append(lines[idx])
            idx += 1

        data = RT_YAML.load(block_text)
        if not isinstance(data, CommentedMap) or len(data) != 1:
            raise ValueError("Each block must be a mapping with a single root key (e.g. Vertical).")
        root_key = next(iter(data))
        inner_map = data[root_key]
        if not isinstance(inner_map, CommentedMap):
            raise ValueError("Top-level entries must map to a dictionary of parameters.")
        entries.append(
            Entry(
                leading_lines=leading_lines,
                separator_lines=separator_lines,
                name_line=name_line,
                after_name_lines=after_name_lines,
                root_key=root_key,
                inner_map=inner_map,
                outer_map=data,
                inherits_base=inherits_base,
                defines_anchor=defines_anchor,
            )
        )
    return ParsedFile(prefix_lines=prefix_lines, entries=entries, suffix_lines=suffix_lines)
def strip_merge_marker(block_lines: List[str], anchor_name: str) -> tuple[List[str], bool]:
    """Drop ``<<: *<anchor_name>`` merge lines from a block.

    Only lines that consist solely of the merge reference (ignoring
    surrounding whitespace) are removed.

    Returns:
        ``(remaining_lines, found)``, where *found* tells whether at least
        one merge reference was removed.
    """
    merge_ref = f"<<: *{anchor_name}"
    kept = [line for line in block_lines if line.strip() != merge_ref]
    return kept, len(kept) != len(block_lines)
def strip_generated_base(parsed: ParsedFile, anchor_name: str) -> ParsedFile:
    """Return a copy of *parsed* without any previously generated base block.

    Prefix lines that contain ``BASE_COMMENT_MARKER`` are dropped.  Entries
    flagged ``defines_anchor`` are removed from the entry list, together with
    their separator, name and trailing lines.  The inner mapping of the last
    such entry becomes ``existing_base``; without one, the incoming
    ``existing_base`` is kept.

    *anchor_name* is accepted for interface symmetry but not used here; the
    anchor was already detected by ``parse_file`` via ``Entry.defines_anchor``.
    """
    kept_prefix = [line for line in parsed.prefix_lines if BASE_COMMENT_MARKER not in line]
    anchored = [entry for entry in parsed.entries if entry.defines_anchor]
    regular = [entry for entry in parsed.entries if not entry.defines_anchor]
    base_map = anchored[-1].inner_map if anchored else parsed.existing_base
    return ParsedFile(
        prefix_lines=kept_prefix,
        entries=regular,
        suffix_lines=parsed.suffix_lines,
        existing_base=base_map,
    )
def ensure_base_inheritance(parsed: ParsedFile) -> None:
    """Flag entries that do not fully match the existing base as inheriting it.

    Nothing happens when there is no (or an empty) existing base.  Every
    entry not already marked ``inherits_base`` is marked when
    ``entry_supports_base`` reports that it lacks a base key or holds a
    different value for one.  Entries are modified in place.

    NOTE(review): this assumes such entries were written as overrides of the
    base (e.g. by an earlier run).  An entry that never inherited from the
    base may afterwards be rendered with a merge reference and thereby gain
    base keys it did not have. Confirm this is intended.
    """
    base = parsed.existing_base
    if not base:
        return
    for entry in parsed.entries:
        if not entry.inherits_base and not entry_supports_base(entry.inner_map, base):
            entry.inherits_base = True
def normalize(value):
    """Recursively turn ruamel containers into plain ``dict``/``list`` objects.

    ``CommentedMap`` becomes ``dict`` and ``CommentedSeq`` becomes ``list``.
    Everything else is returned unchanged, so results can be compared with
    ``==`` independently of ruamel comment and formatting metadata.
    """
    if isinstance(value, CommentedSeq):
        return list(map(normalize, value))
    if not isinstance(value, CommentedMap):
        return value
    return {key: normalize(item) for key, item in value.items()}
def merged_with_base(base_map: CommentedMap, overrides: CommentedMap) -> CommentedMap:
    """Deep-merge *overrides* onto a deep copy of *base_map*.

    Where both sides hold a ``CommentedMap`` under the same key, the two are
    merged recursively.  Any other override value replaces the base value
    as a deep copy.  Neither argument is modified.

    NOTE(review): YAML ``<<`` merge keys are shallow, so for nested mappings
    this deep merge can differ from what a YAML loader produces for an entry
    that uses ``<<: *base``. Confirm this is intended where it models
    inheriting entries.
    """
    merged = copy.deepcopy(base_map)
    for key, override in overrides.items():
        existing = merged[key] if key in merged else None
        if isinstance(existing, CommentedMap) and isinstance(override, CommentedMap):
            merged[key] = merged_with_base(existing, override)
        else:
            merged[key] = copy.deepcopy(override)
    return merged
def values_equal(lhs, rhs) -> bool:
    """Compare two YAML values structurally, ignoring ruamel metadata."""
    left, right = normalize(lhs), normalize(rhs)
    return left == right
def _round_number(value, digits: Optional[int]):
"""Round numeric *value* to *digits* decimal places if requested.
If *digits* is None, the value is returned unchanged. Non-numeric
values are also returned unchanged.
"""
if digits is None:
return value
try:
# Handle ints/floats while leaving other types untouched
if isinstance(value, (int, float)):
return round(value, digits)
except TypeError:
pass
return value
def _round_node(node, digits: Optional[int]):
"""Recursively round all numeric scalars within *node* in-place."""
if digits is None:
return node
if isinstance(node, CommentedMap):
for key, val in list(node.items()):
node[key] = _round_node(val, digits)
return node
if isinstance(node, CommentedSeq):
for idx, val in enumerate(list(node)):
node[idx] = _round_node(val, digits)
return node
return _round_number(node, digits)
def collect_common(
    entries: List[Entry],
    threshold: float,
    fallback_base: Optional[CommentedMap] = None,
) -> CommentedMap:
    """Collect top-level keys whose value is shared by enough entries.

    A key qualifies when it occurs in at least ``threshold`` of all entries
    and every entry that has it holds an equal value (per ``values_equal``).
    For entries flagged ``inherits_base``, the entry map is first merged onto
    *fallback_base* (when given), so inherited values are counted too.

    Returns:
        A new ``CommentedMap`` with deep copies of the shared values, in
        first-seen key order.  It is empty when *entries* is empty.
    """
    common = CommentedMap()
    if not entries:
        return common
    seen: dict[str, List[object]] = {}
    for entry in entries:
        effective = entry.inner_map
        if entry.inherits_base and fallback_base is not None:
            effective = merged_with_base(fallback_base, entry.inner_map)
        for key, value in effective.items():
            seen.setdefault(key, []).append(value)
    total = len(entries)
    for key, values in seen.items():
        if len(values) / total < threshold:
            continue
        reference, *rest = values
        if all(values_equal(reference, candidate) for candidate in rest):
            common[key] = copy.deepcopy(reference)
    return common
def entry_supports_base(entry_map: CommentedMap, base_map: CommentedMap) -> bool:
    """Return True if *entry_map* agrees with every key of *base_map*.

    Each base key must be present in the entry.  When both sides hold a
    ``CommentedMap`` the check recurses, so extra nested keys in the entry
    are allowed.  Otherwise the values must compare equal via
    ``values_equal``.  Top-level keys the entry has beyond the base are
    ignored.
    """
    for key, expected in base_map.items():
        if key not in entry_map:
            return False
        actual = entry_map[key]
        nested = isinstance(expected, CommentedMap) and isinstance(actual, CommentedMap)
        matches = entry_supports_base(actual, expected) if nested else values_equal(actual, expected)
        if not matches:
            return False
    return True
def build_overrides(entry_map: CommentedMap, base_map: CommentedMap) -> CommentedMap:
    """Return the keys of *entry_map* that must be written next to the base merge.

    A key is included, as a deep copy of the entry's value, when it is absent
    from *base_map* or its value differs from the base value (``values_equal``).

    Nested mappings are emitted *whole* rather than as a partial diff.  A
    YAML merge key (``<<: *base``) only merges the top level, so a partial
    nested mapping would replace the base's nested mapping entirely.  The
    nested keys that happened to equal the base would then be silently lost
    when the file is loaded.

    When used to extend a base (merging the result back with a deep merge),
    emitting the whole nested mapping yields the same merged result as a
    partial diff would.
    """
    overrides = CommentedMap()
    for key, entry_value in entry_map.items():
        if key in base_map and values_equal(entry_value, base_map[key]):
            continue
        overrides[key] = copy.deepcopy(entry_value)
    return overrides
def dump_map(node: CommentedMap) -> str:
    """Serialise *node* with the shared round-trip ``RT_YAML`` and return the text."""
    with io.StringIO() as stream:
        RT_YAML.dump(node, stream)
        return stream.getvalue()
def _supports_color(stream) -> bool:
"""Return True if *stream* likely supports ANSI color codes."""
return bool(getattr(stream, "isatty", lambda: False)()) and os.environ.get("TERM") not in {None, "", "dumb"}
def _colorize(text: str, color_code: str = "31") -> str:
    """Wrap *text* in an ANSI SGR color sequence if stderr is a color terminal.

    The capability check is always made against ``sys.stderr``.  When it
    fails, *text* is returned unchanged.
    """
    if not _supports_color(sys.stderr):
        return text
    return f"\033[{color_code}m{text}\033[0m"
def inject_merge_line(rendered: str, anchor_name: str, indent: Optional[str] = None) -> str:
    """Insert ``<<: *<anchor_name>`` directly below the root key of *rendered*.

    Args:
        rendered: YAML text whose first line is the root key (``Key:``).
        anchor_name: Anchor referenced by the merge line.
        indent: Leading whitespace for the merge line.  When None (the
            default), it is copied from the first non-blank, non-comment line
            after the root key, so the merge line lines up with its sibling
            keys; a differently indented key inside one mapping is a YAML
            syntax error.  Falls back to two spaces when no such indented
            line exists.

    Returns:
        The text with the merge line inserted as its second line.  An empty
        *rendered* yields just the merge line.
    """
    lines = rendered.splitlines(keepends=True)
    if indent is None:
        indent = "  "
        for line in lines[1:]:
            content = line.strip()
            if not content or content.startswith("#"):
                continue
            leading = line[: len(line) - len(line.lstrip())]
            if leading:
                indent = leading
            break
    merge_line = f"{indent}<<: *{anchor_name}\n"
    if not lines:
        return merge_line
    if not lines[0].endswith("\n"):
        lines[0] += "\n"
    lines.insert(1, merge_line)
    return "".join(lines)
def render_with_merge(root_key: str, overrides: CommentedMap, anchor_name: str) -> str:
    """Render *root_key* as a mapping that merges the shared anchor.

    The output is ``<root_key>:``, then ``<<: *<anchor_name>``, then the
    dumped *overrides* (if any).  With no overrides, the entry consists of
    the merge reference alone, so it stays present and inherits everything.

    The merge line uses an explicit two-space indent, matching
    ``RT_YAML.indent(mapping=2)``, so it lines up with the dumped override
    keys.  A mismatched indent inside one mapping is invalid YAML.
    """
    merge_indent = "  "
    if not overrides:
        return f"{root_key}:\n{merge_indent}<<: *{anchor_name}\n"
    container = CommentedMap()
    container[root_key] = overrides
    return inject_merge_line(dump_map(container), anchor_name, indent=merge_indent)
def _render_entry(
    entry: Entry,
    base_inner: Optional[CommentedMap],
    anchor_name: str,
    round_digits: Optional[int],
) -> str:
    """Render one entry as base overrides when possible, otherwise in full."""
    if base_inner:
        # Round a private copy *before* the support check.  base_inner has
        # already been rounded, and comparing it against unrounded entry
        # values would make any entry whose common values change under
        # rounding look unrelated to the base.
        inner_copy = copy.deepcopy(entry.inner_map)
        if round_digits is not None:
            _round_node(inner_copy, round_digits)
        if entry_supports_base(inner_copy, base_inner) or entry.inherits_base:
            overrides = build_overrides(inner_copy, base_inner)
            # An empty override set still yields a merge-only entry.
            return render_with_merge(entry.root_key, overrides, anchor_name)
    # No base, or the entry does not fit it: write the whole (rounded) block.
    outer_copy = copy.deepcopy(entry.outer_map)
    if round_digits is not None:
        _round_node(outer_copy, round_digits)
    return dump_map(outer_copy)


def rewrite(
    parsed: ParsedFile,
    base_map: Optional[CommentedMap],
    threshold: float,
    anchor_name: str,
    round_digits: Optional[int],
    target_path: Path,
) -> None:
    """Write *parsed* back to *target_path*, optionally with a shared base block.

    The output contains, in order:

    1. the prefix lines;
    2. when *base_map* is non-empty: a comment line, a ``base:`` mapping
       carrying the anchor ``&<anchor_name>``, and a blank line;
    3. every entry, followed by its preserved separator, name and trailing
       lines;
    4. the suffix lines.

    Entries that match the base, or are flagged ``inherits_base``, are
    written as ``<<: *<anchor_name>`` plus the keys that differ from the
    base.  All other entries are written in full.  With *round_digits* set,
    numbers in the base and in every entry are rounded before comparison
    and output.  *threshold* is only used in the base comment.
    """
    fragments: List[str] = list(parsed.prefix_lines)
    base_inner: Optional[CommentedMap] = None
    if base_map:
        base_inner = copy.deepcopy(base_map)
        # Round the shared base before anchoring, if requested.
        if round_digits is not None:
            _round_node(base_inner, round_digits)
        base_inner.yaml_set_anchor(anchor_name, always_dump=True)
        # The base element is always named "base" in the generated YAML,
        # regardless of the anchor name.
        base_container = CommentedMap()
        base_container["base"] = base_inner
        base_container.yaml_set_start_comment(
            f"Base configuration extracted by YAMLCommon (threshold={threshold:.2f})",
            indent=0,
        )
        fragments.append(dump_map(base_container))
        fragments.append("\n")
    for entry in parsed.entries:
        fragments.extend(entry.leading_lines)
        fragments.append(_render_entry(entry, base_inner, anchor_name, round_digits))
        fragments.extend(entry.separator_lines)
        if entry.name_line:
            name = entry.name_line
            fragments.append(name if name.endswith("\n") else name + "\n")
        fragments.extend(entry.after_name_lines)
    fragments.extend(parsed.suffix_lines)
    target_path.write_text("".join(fragments), encoding="utf-8")
def print_base(base_map: Optional[CommentedMap], threshold: float) -> None:
    """Show the extracted base mapping, nested under a ``base`` key, on stdout.

    Prints a notice instead when *base_map* is None or empty.
    """
    if base_map:
        wrapper = CommentedMap()
        wrapper["base"] = base_map
        print(f"Extracted base element (threshold={threshold:.2f}):")
        sys.stdout.write(dump_map(wrapper))
    else:
        print(f"No base attributes met the threshold ({threshold:.2f}).")
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse *argv*.

    Args:
        argv: Argument list to parse.  None (the default) makes argparse use
            ``sys.argv[1:]``, so the original no-argument call still works.

    Returns:
        The parsed namespace (``yaml_file``, ``threshold``, ``round``,
        ``output``, ``overwrite``, ``anchor_name``).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("yaml_file", type=Path, help="Path to the footprint YAML file")
    parser.add_argument(
        "-t",
        "--threshold",
        type=float,
        default=0.9,
        help="Fraction of entries a key must appear in (default: 0.9)",
    )
    parser.add_argument(
        "-r",
        "--round",
        type=int,
        default=None,
        help=(
            "Round all numeric values to the given number of decimal "
            "places before writing output."
        ),
    )
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        help=(
            "Output path for rewritten YAML file (default: overwrite the input file, "
            "requires --overwrite)."
        ),
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help=(
            "Allow overwriting the original YAML file; required whenever the "
            "output path (the input file by default) resolves to the input file."
        ),
    )
    parser.add_argument(
        "--anchor-name",
        default="base",
        help="Anchor name used for the shared base mapping (default: base)",
    )
    return parser.parse_args(argv)
def main() -> None:
    """Command-line entry point: extract a shared base and rewrite the file.

    Steps:

    1. Parse the arguments and validate the threshold.
    2. Read and parse the YAML file, then detach any previously generated
       base.
    3. Compute the common keys and print the resulting base.
    4. Rewrite the target, unless that would overwrite the input file
       without ``--overwrite``.  In that case a warning is printed to stderr
       and nothing is written.

    Raises:
        SystemExit: on an invalid threshold, an unreadable input file, a
            structurally unsupported file (``ValueError`` from
            ``parse_file``), or a file without entries.
    """
    args = parse_args()
    if not 0 < args.threshold <= 1:
        raise SystemExit("Threshold must be within (0, 1].")
    try:
        text = args.yaml_file.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        raise SystemExit(f"Cannot read {args.yaml_file}: {exc}") from exc
    try:
        parsed = parse_file(text, args.anchor_name)
    except ValueError as exc:
        raise SystemExit(f"Cannot parse {args.yaml_file}: {exc}") from exc
    parsed = strip_generated_base(parsed, args.anchor_name)
    ensure_base_inheritance(parsed)
    if not parsed.entries:
        raise SystemExit("No entries were found in the provided file.")

    common_map = collect_common(parsed.entries, args.threshold, parsed.existing_base)
    if parsed.existing_base:
        # Keep the previous base; add newly common keys and update changed values.
        base_for_write = copy.deepcopy(parsed.existing_base)
        if common_map:
            additional = build_overrides(common_map, base_for_write)
            if additional:
                base_for_write = merged_with_base(base_for_write, additional)
    else:
        base_for_write = copy.deepcopy(common_map) if common_map else None
    print_base(base_for_write, args.threshold)

    target = args.output or args.yaml_file
    # Writing onto the input file (explicitly or by default) needs --overwrite.
    if target.resolve() == args.yaml_file.resolve() and not args.overwrite:
        warning = (
            "--overwrite was not provided; the input file was left unchanged and no output was written."
        )
        print(_colorize(warning), file=sys.stderr)
        return

    rewrite(
        parsed,
        base_for_write or None,
        args.threshold,
        args.anchor_name,
        args.round,
        target,
    )
if __name__ == "__main__":
main()Check out similar posts by category:
Python
If this post helped you, please consider buying me a coffee or donating via PayPal to support research & publishing of new posts on TechOverflow