update to SCC algorithm, tidy internal names

This commit is contained in:
Thomas (Tom) C. Gorordo 2026-05-25 07:08:08 -07:00
parent fcf2505820
commit 674fdc1fe9
Signed by: tgorordo
GPG key ID: 0CBED22BB0D94490
7 changed files with 42 additions and 45 deletions

View file

@ -7,9 +7,6 @@ run *args:
marimo:
uv run marimo --edit
example:
uv run python src/main.py test/test_ballot.csv
format:
uv run ruff format src test

View file

@ -8,7 +8,6 @@ authors = [
]
requires-python = ">=3.13"
dependencies = [
"numpy>=2.4.6",
"polars>=1.40.1",
"rustworkx>=0.17.1",
]

View file

@ -96,9 +96,7 @@ nbformat==5.10.4
nodeenv==1.10.0
# via pyright
numpy==2.4.6
# via
# smithy (pyproject.toml)
# rustworkx
# via rustworkx
openai==2.37.0
# via pydantic-ai-slim
opentelemetry-api==1.42.0

View file

@ -1,16 +1,17 @@
import polars as pl
import rustworkx as rwx
from itertools import combinations
from .rcv import pairmaj_from_rcv
from .rcv import pmg_from_rcv
def smith_set_brutefrom_pairmaj(pairmaj_graph: dict[str, set[str]]) -> list:
def ss_from_pmg(pmg: rwx.PyDiGraph) -> list[str]:
"""
Brute-force the Smith set from a pairwise majority winner graph.
Find the Smith set from a pairwise majority graph.
parameters
---
pairmaj_graph: dict[str, set[str]]
pmg: rwx.PyDiGraph
A graph whose nodes correspond to candidates and (directed) edges show
which candidates they beat pairwise.
@ -22,28 +23,18 @@ def smith_set_brutefrom_pairmaj(pairmaj_graph: dict[str, set[str]]) -> list:
(single Majority winner), the Smith set will contain that single candidate.
"""
candidates = set(pairmaj_graph.keys())
size = len(candidates)
sccs = rwx.strongly_connected_components(pmg)
for size in range(1, len(candidates) + 1):
for sub in combinations(candidates, size):
subset = set(sub)
out = set(candidates) - subset
cg = rwx.condensation(pmg, sccs)
dom = True
src_sccs = [nd for nd in cg.node_indices() if cg.in_degree(nd) == 0]
for member in subset:
if not out.issubset(pairmaj_graph[member]):
dom = False
break
smith_set = sorted([c for scc in src_sccs for c in cg[scc]])
if dom:
return sorted(subset)
return []
return smith_set
def smith_set_from_rcv(rcv_ballots: pl.DataFrame) -> list:
def smith_set_from_rcv(ballots: pl.DataFrame) -> list:
"""
Compute the Smith set from a Ranked-Choice ballot.
@ -52,7 +43,7 @@ def smith_set_from_rcv(rcv_ballots: pl.DataFrame) -> list:
parameters
---
df : pl.DataFrame
ballots : pl.DataFrame
A Polars DataFrame representing ballots. Each column is a candidate and each
row is a voter's ranking of the candidates. Lower numbers indicate higher
preference (1 = top-choice).
@ -66,7 +57,8 @@ def smith_set_from_rcv(rcv_ballots: pl.DataFrame) -> list:
"""
return smith_set_brutefrom_pairmaj(pairmaj_from_rcv(rcv_ballots))
# return smith_set_brutefrom_pairmaj(pairmaj_from_rcv(rcv_ballots))
return ss_from_pmg(pmg_from_rcv(ballots))
def smith_set(df: pl.DataFrame, ballotkind="rcv") -> list:

View file

@ -1,8 +1,9 @@
import polars as pl
import rustworkx as rwx
from itertools import combinations
def pairmaj_from_rcv(rcv_ballots: pl.DataFrame) -> dict[str, set[str]]:
def pmg_from_rcv(ballots: pl.DataFrame) -> rwx.PyDiGraph:
"""
Build a pairwise majority winner graph from a box of Ranked-Choice Ballots.
@ -15,27 +16,38 @@ def pairmaj_from_rcv(rcv_ballots: pl.DataFrame) -> dict[str, set[str]]:
returns
---
pairmaj_graph: dict[str, set[str]]
pmg: rwx.PyDiGraph
A pairwise majority winner graph whose nodes correspond to candidates (each
node's payload is the candidate name) and (directed) edges show which
candidates they beat pairwise, weighted by the margin of victory. The
candidate-name-to-node-id mapping is only used internally and is not returned.
"""
candidates = rcv_ballots.columns
candidates = ballots.columns
pairmaj_graph: dict[str, set[str]] = {c: set() for c in candidates}
pmg = rwx.PyDiGraph()
nodes = {c: pmg.add_node(c) for c in candidates}
for a, b in combinations(candidates, 2):
result = rcv_ballots.select(
exprs = []
pairs = list(combinations(candidates, 2))
for a, b in pairs:
exprs.extend(
[
(pl.col(a) < pl.col(b)).sum().alias("a_wins"),
(pl.col(b) < pl.col(a)).sum().alias("b_wins"),
(pl.col(a) < pl.col(b)).sum().alias(f"{a}>{b}"),
(pl.col(b) < pl.col(a)).sum().alias(f"{b}>{a}"),
]
).row(0)
)
a_wins, b_wins = result
results = ballots.select(exprs).row(0, named=True)
for a, b in pairs:
a_wins = results[f"{a}>{b}"]
b_wins = results[f"{b}>{a}"]
if a_wins > b_wins:
pairmaj_graph[a].add(b)
pmg.add_edge(nodes[a], nodes[b], a_wins - b_wins)
elif b_wins > a_wins:
pairmaj_graph[b].add(a)
pmg.add_edge(nodes[b], nodes[a], b_wins - a_wins)
return pairmaj_graph
return pmg

View file

@ -3,7 +3,8 @@
# dependencies = [
# "click>=8.4.1",
# "rich>=15.0.0",
# "polars>=1.40.1"
# "polars>=1.40.1",
# "rustworkx>=0.17.1"
# ]
# ///
import sys, io

2
uv.lock generated
View file

@ -1556,7 +1556,6 @@ name = "smithy"
version = "0.1.0"
source = { editable = "." }
dependencies = [
{ name = "numpy" },
{ name = "polars" },
{ name = "rustworkx" },
]
@ -1584,7 +1583,6 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "click", marker = "extra == 'cli'", specifier = ">=8.4.0" },
{ name = "numpy", specifier = ">=2.4.6" },
{ name = "polars", specifier = ">=1.40.1" },
{ name = "rich", marker = "extra == 'cli'", specifier = ">=15.0.0" },
{ name = "rustworkx", specifier = ">=0.17.1" },