From e0d0134fddb3c27fa40352085056ad8ed835d9cc Mon Sep 17 00:00:00 2001 From: "Thomas (Tom) C. Gorordo" Date: Fri, 5 Jun 2026 14:39:31 -0700 Subject: [PATCH] add IRV smith-set resolution method --- src/smithy/__init__.py | 10 ++++ src/smithy/irv.py | 102 +++++++++++++++++++++++++++++++++++++++++ src/smithy/rcv.py | 42 +++++++++++------ src/smithycmd.py | 18 ++++++-- 4 files changed, 153 insertions(+), 19 deletions(-) create mode 100644 src/smithy/irv.py diff --git a/src/smithy/__init__.py b/src/smithy/__init__.py index 61d9a84..ec9ec9f 100644 --- a/src/smithy/__init__.py +++ b/src/smithy/__init__.py @@ -3,6 +3,7 @@ import rustworkx as rwx from itertools import combinations from .rcv import pmg_from_rcv +from .irv import irv_from_rcv def ss_from_pmg(pmg: rwx.PyDiGraph) -> list[str]: @@ -67,3 +68,12 @@ def smith_set(df: pl.DataFrame, ballotkind="rcv") -> list: raise NotImplementedError( f"`smith_set` ballotkind={ballotkind} is not implemented." ) + + +def irv_set(df: pl.DataFrame, ballotkind="rcv") -> list: + if ballotkind == "rcv": + return irv_from_rcv(df) + else: + raise NotImplementedError( + f"`irv_set` ballotkind={ballotkind} is not implemented." + ) diff --git a/src/smithy/irv.py b/src/smithy/irv.py new file mode 100644 index 0000000..eb89612 --- /dev/null +++ b/src/smithy/irv.py @@ -0,0 +1,102 @@ +import polars as pl +import numpy as np + + +def irv_from_rcv(ballots: pl.DataFrame, method: str = "bigslow") -> list[str]: + """ + Compute the set of all-paths IRV winners from an RCV ballot. + + parameters + --- + ballots: pl.DataFrame + An RCV table of ballots. + + method: str + Either "bigslow" or "smallfast" for selecting an internal method for counting + first-choices during IRV rounds. Defaults to "bigslow" but you can use "smallfast" + so long as the table of ballots is expected to fit in a reasonable numpy array (after compression). + + returns + --- + winners: list[str] + A lexicographically sorted list of IRV winners. 
If a candidate wins every elimination path + then this set will contain only one entry, otherwise it will contain all candidates that win + at least one IRV elimination path. + """ + compressed = ballots.group_by(ballots.columns).len().rename({"len": "count"}) + return sorted(_irv_winners(compressed, method=method)) + + +def _fst_counts_bigslow(compressed: pl.DataFrame) -> pl.DataFrame: + + surviving = [c for c in compressed.columns if c != "count"] + + fstcexpr = ( + pl.concat_list([pl.col(c) for c in surviving]) + .list.arg_min().map_elements(lambda i: surviving[i], return_dtype=pl.String).alias("first_choice") + ) + + tally = ( + compressed.with_columns(fstcexpr) + .group_by("first_choice") + .agg(pl.col("count").sum()) + .filter(pl.col("first_choice").is_not_null()) + ) + return tally + + +def _fst_counts_smallfast(compressed: pl.DataFrame) -> pl.DataFrame: + + surviving = [c for c in compressed.columns if c != "count"] + + a = compressed.select(surviving).to_numpy() + + cs = compressed["count"].to_numpy() + + fstc_idxs = np.argmin(a, axis=1) + + tally = {c: 0 for c in surviving} + for i, c in zip(fstc_idxs, cs): + tally[surviving[i]] += int(c) + + return pl.DataFrame( + {"first_choice": surviving, "count": [tally[c] for c in surviving]} + ) + + +def _irv_round(compressed: pl.DataFrame, method="bigslow"): + + if method == "bigslow": + count_fn = _fst_counts_bigslow + elif method == "smallfast": + count_fn = _fst_counts_smallfast + else: + raise NotImplementedError( + f"Error: _fst_counts method={method} not implemented." 
+ ) + + tally = count_fn(compressed) + + eliminate = tally.filter(pl.col("count") == pl.col("count").min())[ + "first_choice" + ].to_list() + + for e in eliminate: + surviving = [c for c in compressed.columns if c not in ("count", e)] + yield ( + compressed.select(surviving + ["count"]) + .group_by(surviving) + .agg(pl.col("count").sum()) + ) + + +def _irv_winners(compressed, method="bigslow"): + + surviving = [c for c in compressed.columns if c != "count"] + if len(surviving) == 1: + return set(surviving) + + winners = set() + for branch in _irv_round(compressed, method=method): + winners |= _irv_winners(branch, method=method) + return winners diff --git a/src/smithy/rcv.py b/src/smithy/rcv.py index 99de429..4f51fe2 100644 --- a/src/smithy/rcv.py +++ b/src/smithy/rcv.py @@ -3,7 +3,7 @@ import rustworkx as rwx from itertools import combinations -def pmg_from_rcv_polars(ballots: pl.DataFrame) -> rwx.PyDiGraph: +def pmg_from_rcv_bigslow(ballots: pl.DataFrame) -> rwx.PyDiGraph: """ Build a pairwise majority winner graph from a box of Ranked-Choice Ballots. 
@@ -34,10 +34,20 @@ def pmg_from_rcv_polars(ballots: pl.DataFrame) -> rwx.PyDiGraph: pairs = list(combinations(candidates, 2)) for a, b in pairs: - exprs.extend([ - pl.when(pl.col(a) < pl.col(b)).then(pl.col("count")).otherwise(0).sum().alias(f"{a}>{b}"), - pl.when(pl.col(b) < pl.col(a)).then(pl.col("count")).otherwise(0).sum().alias(f"{b}>{a}") - ]) + exprs.extend( + [ + pl.when(pl.col(a) < pl.col(b)) + .then(pl.col("count")) + .otherwise(0) + .sum() + .alias(f"{a}>{b}"), + pl.when(pl.col(b) < pl.col(a)) + .then(pl.col("count")) + .otherwise(0) + .sum() + .alias(f"{b}>{a}"), + ] + ) results = compressed.select(exprs).row(0, named=True) @@ -52,7 +62,8 @@ def pmg_from_rcv_polars(ballots: pl.DataFrame) -> rwx.PyDiGraph: return pmg -def pmg_from_rcv_numpy(ballots: pl.DataFrame) -> rwx.PyDiGraph: + +def pmg_from_rcv_smallfast(ballots: pl.DataFrame) -> rwx.PyDiGraph: """ Build a pairwise majority winner graph from a box of Ranked-Choice Ballots. @@ -79,17 +90,17 @@ def pmg_from_rcv_numpy(ballots: pl.DataFrame) -> rwx.PyDiGraph: compressed = ballots.group_by(ballots.columns).len().rename({"len": "count"}) counts = compressed["count"].to_numpy() - + arr = compressed.drop("count").to_numpy() results = ((arr[:, :, None] < arr[:, None, :]) * counts[:, None, None]).sum(axis=0) for i, a in enumerate(candidates): for j in range(i + 1, len(candidates)): b = candidates[j] - + a_wins = results[i, j] b_wins = results[j, i] - + if a_wins > b_wins: pmg.add_edge(nodes[a], nodes[b], int(a_wins - b_wins)) elif b_wins > a_wins: @@ -97,10 +108,11 @@ def pmg_from_rcv_numpy(ballots: pl.DataFrame) -> rwx.PyDiGraph: return pmg -def pmg_from_rcv(ballots: pl.DataFrame, method="numpy") -> rwx.PyDiGraph: - if method == "polars": - return pmg_from_rcv_polars(ballots) - elif method == "numpy": - return pmg_from_rcv_numpy(ballots) + +def pmg_from_rcv(ballots: pl.DataFrame, method="bigslow") -> rwx.PyDiGraph: + if method == "bigslow": + return pmg_from_rcv_bigslow(ballots) + elif method == 
"smallfast": + return pmg_from_rcv_smallfast(ballots) else: - raise NotImplementedError(f"`pmg_from_rcv` method={method} not implemented.") \ No newline at end of file + raise NotImplementedError(f"`pmg_from_rcv` method={method} not implemented.") diff --git a/src/smithycmd.py b/src/smithycmd.py index 3b58971..d94521b 100644 --- a/src/smithycmd.py +++ b/src/smithycmd.py @@ -15,11 +15,16 @@ from rich.table import Table from rich.panel import Panel import polars as pl -from smithy import smith_set +from smithy import smith_set, irv_set @click.command() @click.argument("ballots", type=click.Path(exists=True, dir_okay=False)) +@click.option( + "--try-resolve-irv", + is_flag=True, + help="Try to reduce or resolve the Smith set by running all-paths IRV on the set.", +) @click.option( "--show-ballots", "-b", @@ -27,7 +32,7 @@ from smithy import smith_set help="Show relevant ballots (after selections).", ) @click.option("--pretty", "-p", is_flag=True, help="Pretty-print output.") -def cli(ballots: str, show_ballots=False, pretty=False) -> None: +def cli(ballots: str, try_resolve_irv=False, show_ballots=False, pretty=False) -> None: """ Compute the Smith set from a box of ranked-choice ballots -- .csv or .xls(x). 
@@ -56,7 +61,7 @@ def cli(ballots: str, show_ballots=False, pretty=False) -> None: df = df.with_columns( [ pl.col(c) - .cast(pl.Utf8) + .cast(pl.String) .str.strip_chars() .cast(pl.Int64, strict=False) .fill_null(0) @@ -66,6 +71,9 @@ def cli(ballots: str, show_ballots=False, pretty=False) -> None: # Compute Smith set smiths = smith_set(df) + if len(smiths) > 1 and try_resolve_irv: + irv_ballots = df.select(smiths) + smiths = irv_set(irv_ballots) if show_ballots and pretty: preview = Table(title="Ballot Box") @@ -88,7 +96,9 @@ def cli(ballots: str, show_ballots=False, pretty=False) -> None: console.print( Panel.fit( "\n".join(f"• {c}" for c in smiths), - title="Resulting Smith Set", + title="Resulting IRV-resolved Smith Set" + if (try_resolve_irv) + else "Resulting Smith Set", border_style="green", ) )