mirror of
https://github.com/tgorordo/smithy.git
synced 2026-06-05 16:22:15 -07:00
add IRV smith-set resolution method
This commit is contained in:
parent
fc0569e252
commit
e0d0134fdd
4 changed files with 153 additions and 19 deletions
|
|
@ -3,6 +3,7 @@ import rustworkx as rwx
|
|||
from itertools import combinations
|
||||
|
||||
from .rcv import pmg_from_rcv
|
||||
from .irv import irv_from_rcv
|
||||
|
||||
|
||||
def ss_from_pmg(pmg: rwx.PyDiGraph) -> list[str]:
|
||||
|
|
@ -67,3 +68,12 @@ def smith_set(df: pl.DataFrame, ballotkind="rcv") -> list:
|
|||
raise NotImplementedError(
|
||||
f"`smith_set` ballotkind={ballotkind} is not implemented."
|
||||
)
|
||||
|
||||
|
||||
def irv_set(df: pl.DataFrame, ballotkind="rcv") -> list:
|
||||
if ballotkind == "rcv":
|
||||
return irv_from_rcv(df)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"`irv_set` ballotkind={ballotkind} is not implemented."
|
||||
)
|
||||
|
|
|
|||
102
src/smithy/irv.py
Normal file
102
src/smithy/irv.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import polars as pl
|
||||
import numpy as np
|
||||
|
||||
|
||||
def irv_from_rcv(ballots: pl.DataFrame, method: str = "bigslow") -> list[str]:
|
||||
"""
|
||||
Compute the set of all-paths IRV winners from an RCV ballot.
|
||||
|
||||
parameters
|
||||
---
|
||||
ballots: pl.DataFrame
|
||||
An RCV table of ballots.
|
||||
|
||||
method: str
|
||||
Either "bigslow" or "smallfast" for selecting an internal method for counting
|
||||
first-choices during IRV rounds. Defaults to "bigslow" but you can use "smallfast"
|
||||
so long as the table of ballots is expected to fit in a reasonable numpy array (after compression).
|
||||
|
||||
returns
|
||||
---
|
||||
winners: list[srt]
|
||||
A lexicographically sorted list of IRV winners. If a candidate wins every elimination path
|
||||
then this set will contain only one entry, otherwise it will contain all candidates that win
|
||||
at least one IRV elimination path.
|
||||
"""
|
||||
compressed = ballots.group_by(ballots.columns).len().rename({"len": "count"})
|
||||
return sorted(_irv_winners(compressed, method=method))
|
||||
|
||||
|
||||
def _fst_counts_bigslow(compressed: pl.DataFrame) -> pl.DataFrame:
|
||||
|
||||
surviving = [c for c in compressed.columns if c != "count"]
|
||||
|
||||
fstcexpr = (
|
||||
pl.concat_list([pl.col(c) for c in surviving])
|
||||
.list.arg_min().map_elements(lambda i: surviving[i], return_dtype=pl.String).alias("first_choice")
|
||||
)
|
||||
|
||||
tally = (
|
||||
compressed.with_columns(fstcexpr)
|
||||
.group_by("first_choice")
|
||||
.agg(pl.col("count").sum())
|
||||
.filter(pl.col("first_choice").is_not_null())
|
||||
)
|
||||
return tally
|
||||
|
||||
|
||||
def _fst_counts_smallfast(compressed: pl.DataFrame) -> pl.DataFrame:
|
||||
|
||||
surviving = [c for c in compressed.columns if c != "count"]
|
||||
|
||||
a = compressed.select(surviving).to_numpy()
|
||||
|
||||
cs = compressed["count"].to_numpy()
|
||||
|
||||
fstc_idxs = np.argmin(a, axis=1)
|
||||
|
||||
tally = {c: 0 for c in surviving}
|
||||
for i, c in zip(fstc_idxs, cs):
|
||||
tally[surviving[i]] += int(c)
|
||||
|
||||
return pl.DataFrame(
|
||||
{"first_choice": surviving, "count": [tally[c] for c in surviving]}
|
||||
)
|
||||
|
||||
|
||||
def _irv_round(compressed: pl.DataFrame, method="bigslow"):
|
||||
|
||||
if method == "bigslow":
|
||||
count_fn = _fst_counts_bigslow
|
||||
elif method == "smallfast":
|
||||
count_fn = _fst_counts_smallfast
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Error: _fst_counts method={method} not implemented."
|
||||
)
|
||||
|
||||
tally = count_fn(compressed)
|
||||
|
||||
eliminate = tally.filter(pl.col("count") == pl.col("count").min())[
|
||||
"first_choice"
|
||||
].to_list()
|
||||
|
||||
for e in eliminate:
|
||||
surviving = [c for c in compressed.columns if c not in ("count", e)]
|
||||
yield (
|
||||
compressed.select(surviving + ["count"])
|
||||
.group_by(surviving)
|
||||
.agg(pl.col("count").sum())
|
||||
)
|
||||
|
||||
|
||||
def _irv_winners(compressed, method="bigslow"):
|
||||
|
||||
surviving = [c for c in compressed.columns if c != "count"]
|
||||
if len(surviving) == 1:
|
||||
return set(surviving)
|
||||
|
||||
winners = set()
|
||||
for branch in _irv_round(compressed, method=method):
|
||||
winners |= _irv_winners(branch, method=method)
|
||||
return winners
|
||||
|
|
@ -3,7 +3,7 @@ import rustworkx as rwx
|
|||
from itertools import combinations
|
||||
|
||||
|
||||
def pmg_from_rcv_polars(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
||||
def pmg_from_rcv_bigslow(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
||||
"""
|
||||
Build a pairwise majority winner graph from a box of Ranked-Choice Ballots.
|
||||
|
||||
|
|
@ -34,10 +34,20 @@ def pmg_from_rcv_polars(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
|||
pairs = list(combinations(candidates, 2))
|
||||
|
||||
for a, b in pairs:
|
||||
exprs.extend([
|
||||
pl.when(pl.col(a) < pl.col(b)).then(pl.col("count")).otherwise(0).sum().alias(f"{a}>{b}"),
|
||||
pl.when(pl.col(b) < pl.col(a)).then(pl.col("count")).otherwise(0).sum().alias(f"{b}>{a}")
|
||||
])
|
||||
exprs.extend(
|
||||
[
|
||||
pl.when(pl.col(a) < pl.col(b))
|
||||
.then(pl.col("count"))
|
||||
.otherwise(0)
|
||||
.sum()
|
||||
.alias(f"{a}>{b}"),
|
||||
pl.when(pl.col(b) < pl.col(a))
|
||||
.then(pl.col("count"))
|
||||
.otherwise(0)
|
||||
.sum()
|
||||
.alias(f"{b}>{a}"),
|
||||
]
|
||||
)
|
||||
|
||||
results = compressed.select(exprs).row(0, named=True)
|
||||
|
||||
|
|
@ -52,7 +62,8 @@ def pmg_from_rcv_polars(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
|||
|
||||
return pmg
|
||||
|
||||
def pmg_from_rcv_numpy(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
||||
|
||||
def pmg_from_rcv_smallfast(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
||||
"""
|
||||
Build a pairwise majority winner graph from a box of Ranked-Choice Ballots.
|
||||
|
||||
|
|
@ -79,17 +90,17 @@ def pmg_from_rcv_numpy(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
|||
|
||||
compressed = ballots.group_by(ballots.columns).len().rename({"len": "count"})
|
||||
counts = compressed["count"].to_numpy()
|
||||
|
||||
|
||||
arr = compressed.drop("count").to_numpy()
|
||||
results = ((arr[:, :, None] < arr[:, None, :]) * counts[:, None, None]).sum(axis=0)
|
||||
|
||||
for i, a in enumerate(candidates):
|
||||
for j in range(i + 1, len(candidates)):
|
||||
b = candidates[j]
|
||||
|
||||
|
||||
a_wins = results[i, j]
|
||||
b_wins = results[j, i]
|
||||
|
||||
|
||||
if a_wins > b_wins:
|
||||
pmg.add_edge(nodes[a], nodes[b], int(a_wins - b_wins))
|
||||
elif b_wins > a_wins:
|
||||
|
|
@ -97,10 +108,11 @@ def pmg_from_rcv_numpy(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
|||
|
||||
return pmg
|
||||
|
||||
def pmg_from_rcv(ballots: pl.DataFrame, method="numpy") -> rwx.PyDiGraph:
|
||||
if method == "polars":
|
||||
return pmg_from_rcv_polars(ballots)
|
||||
elif method == "numpy":
|
||||
return pmg_from_rcv_numpy(ballots)
|
||||
|
||||
def pmg_from_rcv(ballots: pl.DataFrame, method="bigslow") -> rwx.PyDiGraph:
|
||||
if method == "bigslow":
|
||||
return pmg_from_rcv_bigslow(ballots)
|
||||
elif method == "smallfast":
|
||||
return pmg_from_rcv_smallfast(ballots)
|
||||
else:
|
||||
raise NotImplementedError(f"`pmg_from_rcv` method={method} not implemented.")
|
||||
raise NotImplementedError(f"`pmg_from_rcv` method={method} not implemented.")
|
||||
|
|
|
|||
|
|
@ -15,11 +15,16 @@ from rich.table import Table
|
|||
from rich.panel import Panel
|
||||
|
||||
import polars as pl
|
||||
from smithy import smith_set
|
||||
from smithy import smith_set, irv_set
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("ballots", type=click.Path(exists=True, dir_okay=False))
|
||||
@click.option(
|
||||
"--try-resolve-irv",
|
||||
is_flag=True,
|
||||
help="Try to reduce or resolve the Smith set by running all-paths IRV on the set.",
|
||||
)
|
||||
@click.option(
|
||||
"--show-ballots",
|
||||
"-b",
|
||||
|
|
@ -27,7 +32,7 @@ from smithy import smith_set
|
|||
help="Show relevant ballots (after selections).",
|
||||
)
|
||||
@click.option("--pretty", "-p", is_flag=True, help="Pretty-print output.")
|
||||
def cli(ballots: str, show_ballots=False, pretty=False) -> None:
|
||||
def cli(ballots: str, try_resolve_irv=False, show_ballots=False, pretty=False) -> None:
|
||||
"""
|
||||
Compute the Smith set from a box of ranked-choice ballots -- .csv or .xls(x).
|
||||
|
||||
|
|
@ -56,7 +61,7 @@ def cli(ballots: str, show_ballots=False, pretty=False) -> None:
|
|||
df = df.with_columns(
|
||||
[
|
||||
pl.col(c)
|
||||
.cast(pl.Utf8)
|
||||
.cast(pl.String)
|
||||
.str.strip_chars()
|
||||
.cast(pl.Int64, strict=False)
|
||||
.fill_null(0)
|
||||
|
|
@ -66,6 +71,9 @@ def cli(ballots: str, show_ballots=False, pretty=False) -> None:
|
|||
|
||||
# Compute Smith set
|
||||
smiths = smith_set(df)
|
||||
if len(smiths) > 1 and try_resolve_irv:
|
||||
irv_ballots = df.select(smiths)
|
||||
smiths = irv_set(irv_ballots)
|
||||
|
||||
if show_ballots and pretty:
|
||||
preview = Table(title="Ballot Box")
|
||||
|
|
@ -88,7 +96,9 @@ def cli(ballots: str, show_ballots=False, pretty=False) -> None:
|
|||
console.print(
|
||||
Panel.fit(
|
||||
"\n".join(f"• {c}" for c in smiths),
|
||||
title="Resulting Smith Set",
|
||||
title="Resulting IRV-resolved Smith Set"
|
||||
if (try_resolve_irv)
|
||||
else "Resulting Smith Set",
|
||||
border_style="green",
|
||||
)
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue