add IRV smith-set resolution method

This commit is contained in:
Thomas (Tom) C. Gorordo 2026-06-05 14:39:31 -07:00
parent fc0569e252
commit e0d0134fdd
Signed by: tgorordo
GPG key ID: 0CBED22BB0D94490
4 changed files with 153 additions and 19 deletions

View file

@ -3,6 +3,7 @@ import rustworkx as rwx
from itertools import combinations from itertools import combinations
from .rcv import pmg_from_rcv from .rcv import pmg_from_rcv
from .irv import irv_from_rcv
def ss_from_pmg(pmg: rwx.PyDiGraph) -> list[str]: def ss_from_pmg(pmg: rwx.PyDiGraph) -> list[str]:
@ -67,3 +68,12 @@ def smith_set(df: pl.DataFrame, ballotkind="rcv") -> list:
raise NotImplementedError( raise NotImplementedError(
f"`smith_set` ballotkind={ballotkind} is not implemented." f"`smith_set` ballotkind={ballotkind} is not implemented."
) )
def irv_set(df: pl.DataFrame, ballotkind="rcv") -> list:
    """
    Compute the set of IRV winners for a box of ballots.

    parameters
    ---
    df: pl.DataFrame
        The table of ballots, handed unchanged to the ballot-kind specific routine.
    ballotkind: str
        The kind of ballot held in `df`. Only "rcv" is handled here.

    returns
    ---
    winners: list
        Whatever `irv_from_rcv` returns for `df`.

    raises
    ---
    NotImplementedError
        If `ballotkind` is anything other than "rcv".
    """
    # Reject unsupported ballot kinds up front so the supported path reads straight through.
    if ballotkind != "rcv":
        raise NotImplementedError(
            f"`irv_set` ballotkind={ballotkind} is not implemented."
        )
    return irv_from_rcv(df)

102
src/smithy/irv.py Normal file
View file

@ -0,0 +1,102 @@
import polars as pl
import numpy as np
def irv_from_rcv(ballots: pl.DataFrame, method: str = "bigslow") -> list[str]:
    """
    Compute the set of all-paths IRV winners from an RCV ballot.

    parameters
    ---
    ballots: pl.DataFrame
        An RCV table of ballots.
    method: str
        Either "bigslow" or "smallfast", selecting the internal routine used to count
        first-choices during IRV rounds. Defaults to "bigslow"; "smallfast" can be used
        so long as the (compressed) table of ballots is expected to fit in a reasonable
        numpy array.

    returns
    ---
    winners: list[str]
        A sorted list of IRV winners. If a candidate wins every elimination path this
        list contains only that one entry, otherwise it contains every candidate that
        wins at least one IRV elimination path.
    """
    # Collapse identical ballots into one row each, keeping the multiplicity in "count".
    grouped = ballots.group_by(ballots.columns).len()
    compressed = grouped.rename({"len": "count"})

    winners = list(_irv_winners(compressed, method=method))
    winners.sort()
    return winners
def _fst_counts_bigslow(compressed: pl.DataFrame) -> pl.DataFrame:
    """
    Tally first-choice votes per surviving candidate using polars expressions only.

    parameters
    ---
    compressed: pl.DataFrame
        One column per surviving candidate plus a "count" column holding how many
        ballots share that row.

    returns
    ---
    tally: pl.DataFrame
        Columns "first_choice" and "count", with one row per surviving candidate.
        Candidates that are nobody's first choice are reported with a count of 0,
        matching `_fst_counts_smallfast`.
    """
    surviving = [c for c in compressed.columns if c != "count"]

    # The first choice on a row is the candidate column holding the smallest value.
    fstcexpr = (
        pl.concat_list([pl.col(c) for c in surviving])
        .list.arg_min()
        .map_elements(lambda i: surviving[i], return_dtype=pl.String)
        .alias("first_choice")
    )
    voted = (
        compressed.with_columns(fstcexpr)
        .group_by("first_choice")
        .agg(pl.col("count").sum())
        .filter(pl.col("first_choice").is_not_null())
    )

    # group_by only emits candidates that received at least one first-choice vote.
    # Without the zero rows the round minimum is taken over vote-getters only, so a
    # candidate with no first-choice support is never eliminated (and a unanimous
    # favourite would be the one eliminated). Re-attach every surviving candidate.
    everyone = pl.DataFrame(
        {"first_choice": surviving}, schema={"first_choice": pl.String}
    )
    tally = everyone.join(voted, on="first_choice", how="left").with_columns(
        pl.col("count").fill_null(0)
    )
    return tally
def _fst_counts_smallfast(compressed: pl.DataFrame) -> pl.DataFrame:
    """
    Tally first-choice votes per surviving candidate via a dense numpy array.

    parameters
    ---
    compressed: pl.DataFrame
        One column per surviving candidate plus a "count" column holding how many
        ballots share that row.

    returns
    ---
    tally: pl.DataFrame
        Columns "first_choice" and "count", one row per surviving candidate in column
        order; candidates with no first-choice votes get a count of 0.
    """
    names = [col for col in compressed.columns if col != "count"]
    ranks = compressed.select(names).to_numpy()
    weights = compressed["count"].to_numpy().astype(np.int64)

    # Index of the smallest value on each row, i.e. that row's first choice.
    top = np.argmin(ranks, axis=1)

    # Unbuffered add so rows sharing a first choice all contribute their weight.
    totals = np.zeros(len(names), dtype=np.int64)
    np.add.at(totals, top, weights)

    return pl.DataFrame({"first_choice": names, "count": totals.tolist()})
def _irv_round(compressed: pl.DataFrame, method="bigslow"):
    """
    Run one IRV elimination round, yielding one reduced ballot table per candidate
    tied for the fewest first-choice votes.

    parameters
    ---
    compressed: pl.DataFrame
        One column per surviving candidate plus a "count" column.
    method: str
        "bigslow" or "smallfast", selecting the first-choice counting routine.

    yields
    ---
    branch: pl.DataFrame
        `compressed` with one eliminated candidate's column dropped and rows that
        became identical merged by summing their "count".

    raises
    ---
    NotImplementedError
        If `method` is not one of the two known names (raised when iteration starts,
        since this is a generator).
    """
    if method not in ("bigslow", "smallfast"):
        raise NotImplementedError(
            f"Error: _fst_counts method={method} not implemented."
        )
    count_fn = _fst_counts_bigslow if method == "bigslow" else _fst_counts_smallfast

    tally = count_fn(compressed)
    is_lowest = pl.col("count") == pl.col("count").min()
    losers = tally.filter(is_lowest)["first_choice"].to_list()

    columns = compressed.columns
    for loser in losers:
        # Every candidate column except the one being eliminated in this branch.
        kept = [name for name in columns if name != "count" and name != loser]
        narrowed = compressed.select([*kept, "count"])
        yield narrowed.group_by(kept).agg(pl.col("count").sum())
def _irv_winners(compressed, method="bigslow"):
    """
    Collect every candidate that wins at least one IRV elimination path.

    parameters
    ---
    compressed: pl.DataFrame
        One column per surviving candidate plus a "count" column.
    method: str
        Counting routine name, forwarded to `_irv_round`.

    returns
    ---
    winners: set
        Names of all candidates left standing alone at the end of some elimination
        path. Empty if `compressed` has no candidate columns.
    """
    winners = set()
    # Tied eliminations reach the same set of surviving candidates through many
    # different orders. The reduced table depends only on which candidates survive,
    # so each subset is expanded once instead of once per path.
    expanded = set()
    pending = [compressed]

    while pending:
        table = pending.pop()
        surviving = [c for c in table.columns if c != "count"]

        key = frozenset(surviving)
        if key in expanded:
            continue
        expanded.add(key)

        if len(surviving) <= 1:
            # A lone survivor is a winner; no candidates at all contributes nothing.
            winners.update(surviving)
            continue

        pending.extend(_irv_round(table, method=method))

    return winners

View file

@ -3,7 +3,7 @@ import rustworkx as rwx
from itertools import combinations from itertools import combinations
def pmg_from_rcv_polars(ballots: pl.DataFrame) -> rwx.PyDiGraph: def pmg_from_rcv_bigslow(ballots: pl.DataFrame) -> rwx.PyDiGraph:
""" """
Build a pairwise majority winner graph from a box of Ranked-Choice Ballots. Build a pairwise majority winner graph from a box of Ranked-Choice Ballots.
@ -34,10 +34,20 @@ def pmg_from_rcv_polars(ballots: pl.DataFrame) -> rwx.PyDiGraph:
pairs = list(combinations(candidates, 2)) pairs = list(combinations(candidates, 2))
for a, b in pairs: for a, b in pairs:
exprs.extend([ exprs.extend(
pl.when(pl.col(a) < pl.col(b)).then(pl.col("count")).otherwise(0).sum().alias(f"{a}>{b}"), [
pl.when(pl.col(b) < pl.col(a)).then(pl.col("count")).otherwise(0).sum().alias(f"{b}>{a}") pl.when(pl.col(a) < pl.col(b))
]) .then(pl.col("count"))
.otherwise(0)
.sum()
.alias(f"{a}>{b}"),
pl.when(pl.col(b) < pl.col(a))
.then(pl.col("count"))
.otherwise(0)
.sum()
.alias(f"{b}>{a}"),
]
)
results = compressed.select(exprs).row(0, named=True) results = compressed.select(exprs).row(0, named=True)
@ -52,7 +62,8 @@ def pmg_from_rcv_polars(ballots: pl.DataFrame) -> rwx.PyDiGraph:
return pmg return pmg
def pmg_from_rcv_numpy(ballots: pl.DataFrame) -> rwx.PyDiGraph:
def pmg_from_rcv_smallfast(ballots: pl.DataFrame) -> rwx.PyDiGraph:
""" """
Build a pairwise majority winner graph from a box of Ranked-Choice Ballots. Build a pairwise majority winner graph from a box of Ranked-Choice Ballots.
@ -79,17 +90,17 @@ def pmg_from_rcv_numpy(ballots: pl.DataFrame) -> rwx.PyDiGraph:
compressed = ballots.group_by(ballots.columns).len().rename({"len": "count"}) compressed = ballots.group_by(ballots.columns).len().rename({"len": "count"})
counts = compressed["count"].to_numpy() counts = compressed["count"].to_numpy()
arr = compressed.drop("count").to_numpy() arr = compressed.drop("count").to_numpy()
results = ((arr[:, :, None] < arr[:, None, :]) * counts[:, None, None]).sum(axis=0) results = ((arr[:, :, None] < arr[:, None, :]) * counts[:, None, None]).sum(axis=0)
for i, a in enumerate(candidates): for i, a in enumerate(candidates):
for j in range(i + 1, len(candidates)): for j in range(i + 1, len(candidates)):
b = candidates[j] b = candidates[j]
a_wins = results[i, j] a_wins = results[i, j]
b_wins = results[j, i] b_wins = results[j, i]
if a_wins > b_wins: if a_wins > b_wins:
pmg.add_edge(nodes[a], nodes[b], int(a_wins - b_wins)) pmg.add_edge(nodes[a], nodes[b], int(a_wins - b_wins))
elif b_wins > a_wins: elif b_wins > a_wins:
@ -97,10 +108,11 @@ def pmg_from_rcv_numpy(ballots: pl.DataFrame) -> rwx.PyDiGraph:
return pmg return pmg
def pmg_from_rcv(ballots: pl.DataFrame, method="numpy") -> rwx.PyDiGraph:
if method == "polars": def pmg_from_rcv(ballots: pl.DataFrame, method="bigslow") -> rwx.PyDiGraph:
return pmg_from_rcv_polars(ballots) if method == "bigslow":
elif method == "numpy": return pmg_from_rcv_bigslow(ballots)
return pmg_from_rcv_numpy(ballots) elif method == "smallfast":
return pmg_from_rcv_smallfast(ballots)
else: else:
raise NotImplementedError(f"`pmg_from_rcv` method={method} not implemented.") raise NotImplementedError(f"`pmg_from_rcv` method={method} not implemented.")

View file

@ -15,11 +15,16 @@ from rich.table import Table
from rich.panel import Panel from rich.panel import Panel
import polars as pl import polars as pl
from smithy import smith_set from smithy import smith_set, irv_set
@click.command() @click.command()
@click.argument("ballots", type=click.Path(exists=True, dir_okay=False)) @click.argument("ballots", type=click.Path(exists=True, dir_okay=False))
@click.option(
"--try-resolve-irv",
is_flag=True,
help="Try to reduce or resolve the Smith set by running all-paths IRV on the set.",
)
@click.option( @click.option(
"--show-ballots", "--show-ballots",
"-b", "-b",
@ -27,7 +32,7 @@ from smithy import smith_set
help="Show relevant ballots (after selections).", help="Show relevant ballots (after selections).",
) )
@click.option("--pretty", "-p", is_flag=True, help="Pretty-print output.") @click.option("--pretty", "-p", is_flag=True, help="Pretty-print output.")
def cli(ballots: str, show_ballots=False, pretty=False) -> None: def cli(ballots: str, try_resolve_irv=False, show_ballots=False, pretty=False) -> None:
""" """
Compute the Smith set from a box of ranked-choice ballots -- .csv or .xls(x). Compute the Smith set from a box of ranked-choice ballots -- .csv or .xls(x).
@ -56,7 +61,7 @@ def cli(ballots: str, show_ballots=False, pretty=False) -> None:
df = df.with_columns( df = df.with_columns(
[ [
pl.col(c) pl.col(c)
.cast(pl.Utf8) .cast(pl.String)
.str.strip_chars() .str.strip_chars()
.cast(pl.Int64, strict=False) .cast(pl.Int64, strict=False)
.fill_null(0) .fill_null(0)
@ -66,6 +71,9 @@ def cli(ballots: str, show_ballots=False, pretty=False) -> None:
# Compute Smith set # Compute Smith set
smiths = smith_set(df) smiths = smith_set(df)
if len(smiths) > 1 and try_resolve_irv:
irv_ballots = df.select(smiths)
smiths = irv_set(irv_ballots)
if show_ballots and pretty: if show_ballots and pretty:
preview = Table(title="Ballot Box") preview = Table(title="Ballot Box")
@ -88,7 +96,9 @@ def cli(ballots: str, show_ballots=False, pretty=False) -> None:
console.print( console.print(
Panel.fit( Panel.fit(
"\n".join(f"{c}" for c in smiths), "\n".join(f"{c}" for c in smiths),
title="Resulting Smith Set", title="Resulting IRV-resolved Smith Set"
if (try_resolve_irv)
else "Resulting Smith Set",
border_style="green", border_style="green",
) )
) )