mirror of
https://github.com/tgorordo/smithy.git
synced 2026-06-05 16:22:15 -07:00
add IRV smith-set resolution method
This commit is contained in:
parent
fc0569e252
commit
e0d0134fdd
4 changed files with 153 additions and 19 deletions
|
|
@ -3,6 +3,7 @@ import rustworkx as rwx
|
||||||
from itertools import combinations
|
from itertools import combinations
|
||||||
|
|
||||||
from .rcv import pmg_from_rcv
|
from .rcv import pmg_from_rcv
|
||||||
|
from .irv import irv_from_rcv
|
||||||
|
|
||||||
|
|
||||||
def ss_from_pmg(pmg: rwx.PyDiGraph) -> list[str]:
|
def ss_from_pmg(pmg: rwx.PyDiGraph) -> list[str]:
|
||||||
|
|
@ -67,3 +68,12 @@ def smith_set(df: pl.DataFrame, ballotkind="rcv") -> list:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"`smith_set` ballotkind={ballotkind} is not implemented."
|
f"`smith_set` ballotkind={ballotkind} is not implemented."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def irv_set(df: pl.DataFrame, ballotkind="rcv") -> list:
|
||||||
|
if ballotkind == "rcv":
|
||||||
|
return irv_from_rcv(df)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"`irv_set` ballotkind={ballotkind} is not implemented."
|
||||||
|
)
|
||||||
|
|
|
||||||
102
src/smithy/irv.py
Normal file
102
src/smithy/irv.py
Normal file
|
|
@ -0,0 +1,102 @@
|
||||||
|
import polars as pl
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def irv_from_rcv(ballots: pl.DataFrame, method: str = "bigslow") -> list[str]:
|
||||||
|
"""
|
||||||
|
Compute the set of all-paths IRV winners from an RCV ballot.
|
||||||
|
|
||||||
|
parameters
|
||||||
|
---
|
||||||
|
ballots: pl.DataFrame
|
||||||
|
An RCV table of ballots.
|
||||||
|
|
||||||
|
method: str
|
||||||
|
Either "bigslow" or "smallfast" for selecting an internal method for counting
|
||||||
|
first-choices during IRV rounds. Defaults to "bigslow" but you can use "smallfast"
|
||||||
|
so long as the table of ballots is expected to fit in a reasonable numpy array (after compression).
|
||||||
|
|
||||||
|
returns
|
||||||
|
---
|
||||||
|
winners: list[srt]
|
||||||
|
A lexicographically sorted list of IRV winners. If a candidate wins every elimination path
|
||||||
|
then this set will contain only one entry, otherwise it will contain all candidates that win
|
||||||
|
at least one IRV elimination path.
|
||||||
|
"""
|
||||||
|
compressed = ballots.group_by(ballots.columns).len().rename({"len": "count"})
|
||||||
|
return sorted(_irv_winners(compressed, method=method))
|
||||||
|
|
||||||
|
|
||||||
|
def _fst_counts_bigslow(compressed: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
|
||||||
|
surviving = [c for c in compressed.columns if c != "count"]
|
||||||
|
|
||||||
|
fstcexpr = (
|
||||||
|
pl.concat_list([pl.col(c) for c in surviving])
|
||||||
|
.list.arg_min().map_elements(lambda i: surviving[i], return_dtype=pl.String).alias("first_choice")
|
||||||
|
)
|
||||||
|
|
||||||
|
tally = (
|
||||||
|
compressed.with_columns(fstcexpr)
|
||||||
|
.group_by("first_choice")
|
||||||
|
.agg(pl.col("count").sum())
|
||||||
|
.filter(pl.col("first_choice").is_not_null())
|
||||||
|
)
|
||||||
|
return tally
|
||||||
|
|
||||||
|
|
||||||
|
def _fst_counts_smallfast(compressed: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
|
||||||
|
surviving = [c for c in compressed.columns if c != "count"]
|
||||||
|
|
||||||
|
a = compressed.select(surviving).to_numpy()
|
||||||
|
|
||||||
|
cs = compressed["count"].to_numpy()
|
||||||
|
|
||||||
|
fstc_idxs = np.argmin(a, axis=1)
|
||||||
|
|
||||||
|
tally = {c: 0 for c in surviving}
|
||||||
|
for i, c in zip(fstc_idxs, cs):
|
||||||
|
tally[surviving[i]] += int(c)
|
||||||
|
|
||||||
|
return pl.DataFrame(
|
||||||
|
{"first_choice": surviving, "count": [tally[c] for c in surviving]}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _irv_round(compressed: pl.DataFrame, method="bigslow"):
|
||||||
|
|
||||||
|
if method == "bigslow":
|
||||||
|
count_fn = _fst_counts_bigslow
|
||||||
|
elif method == "smallfast":
|
||||||
|
count_fn = _fst_counts_smallfast
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Error: _fst_counts method={method} not implemented."
|
||||||
|
)
|
||||||
|
|
||||||
|
tally = count_fn(compressed)
|
||||||
|
|
||||||
|
eliminate = tally.filter(pl.col("count") == pl.col("count").min())[
|
||||||
|
"first_choice"
|
||||||
|
].to_list()
|
||||||
|
|
||||||
|
for e in eliminate:
|
||||||
|
surviving = [c for c in compressed.columns if c not in ("count", e)]
|
||||||
|
yield (
|
||||||
|
compressed.select(surviving + ["count"])
|
||||||
|
.group_by(surviving)
|
||||||
|
.agg(pl.col("count").sum())
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _irv_winners(compressed, method="bigslow"):
|
||||||
|
|
||||||
|
surviving = [c for c in compressed.columns if c != "count"]
|
||||||
|
if len(surviving) == 1:
|
||||||
|
return set(surviving)
|
||||||
|
|
||||||
|
winners = set()
|
||||||
|
for branch in _irv_round(compressed, method=method):
|
||||||
|
winners |= _irv_winners(branch, method=method)
|
||||||
|
return winners
|
||||||
|
|
@ -3,7 +3,7 @@ import rustworkx as rwx
|
||||||
from itertools import combinations
|
from itertools import combinations
|
||||||
|
|
||||||
|
|
||||||
def pmg_from_rcv_polars(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
def pmg_from_rcv_bigslow(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
||||||
"""
|
"""
|
||||||
Build a pairwise majority winner graph from a box of Ranked-Choice Ballots.
|
Build a pairwise majority winner graph from a box of Ranked-Choice Ballots.
|
||||||
|
|
||||||
|
|
@ -34,10 +34,20 @@ def pmg_from_rcv_polars(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
||||||
pairs = list(combinations(candidates, 2))
|
pairs = list(combinations(candidates, 2))
|
||||||
|
|
||||||
for a, b in pairs:
|
for a, b in pairs:
|
||||||
exprs.extend([
|
exprs.extend(
|
||||||
pl.when(pl.col(a) < pl.col(b)).then(pl.col("count")).otherwise(0).sum().alias(f"{a}>{b}"),
|
[
|
||||||
pl.when(pl.col(b) < pl.col(a)).then(pl.col("count")).otherwise(0).sum().alias(f"{b}>{a}")
|
pl.when(pl.col(a) < pl.col(b))
|
||||||
])
|
.then(pl.col("count"))
|
||||||
|
.otherwise(0)
|
||||||
|
.sum()
|
||||||
|
.alias(f"{a}>{b}"),
|
||||||
|
pl.when(pl.col(b) < pl.col(a))
|
||||||
|
.then(pl.col("count"))
|
||||||
|
.otherwise(0)
|
||||||
|
.sum()
|
||||||
|
.alias(f"{b}>{a}"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
results = compressed.select(exprs).row(0, named=True)
|
results = compressed.select(exprs).row(0, named=True)
|
||||||
|
|
||||||
|
|
@ -52,7 +62,8 @@ def pmg_from_rcv_polars(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
||||||
|
|
||||||
return pmg
|
return pmg
|
||||||
|
|
||||||
def pmg_from_rcv_numpy(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
|
||||||
|
def pmg_from_rcv_smallfast(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
||||||
"""
|
"""
|
||||||
Build a pairwise majority winner graph from a box of Ranked-Choice Ballots.
|
Build a pairwise majority winner graph from a box of Ranked-Choice Ballots.
|
||||||
|
|
||||||
|
|
@ -79,17 +90,17 @@ def pmg_from_rcv_numpy(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
||||||
|
|
||||||
compressed = ballots.group_by(ballots.columns).len().rename({"len": "count"})
|
compressed = ballots.group_by(ballots.columns).len().rename({"len": "count"})
|
||||||
counts = compressed["count"].to_numpy()
|
counts = compressed["count"].to_numpy()
|
||||||
|
|
||||||
arr = compressed.drop("count").to_numpy()
|
arr = compressed.drop("count").to_numpy()
|
||||||
results = ((arr[:, :, None] < arr[:, None, :]) * counts[:, None, None]).sum(axis=0)
|
results = ((arr[:, :, None] < arr[:, None, :]) * counts[:, None, None]).sum(axis=0)
|
||||||
|
|
||||||
for i, a in enumerate(candidates):
|
for i, a in enumerate(candidates):
|
||||||
for j in range(i + 1, len(candidates)):
|
for j in range(i + 1, len(candidates)):
|
||||||
b = candidates[j]
|
b = candidates[j]
|
||||||
|
|
||||||
a_wins = results[i, j]
|
a_wins = results[i, j]
|
||||||
b_wins = results[j, i]
|
b_wins = results[j, i]
|
||||||
|
|
||||||
if a_wins > b_wins:
|
if a_wins > b_wins:
|
||||||
pmg.add_edge(nodes[a], nodes[b], int(a_wins - b_wins))
|
pmg.add_edge(nodes[a], nodes[b], int(a_wins - b_wins))
|
||||||
elif b_wins > a_wins:
|
elif b_wins > a_wins:
|
||||||
|
|
@ -97,10 +108,11 @@ def pmg_from_rcv_numpy(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
||||||
|
|
||||||
return pmg
|
return pmg
|
||||||
|
|
||||||
def pmg_from_rcv(ballots: pl.DataFrame, method="numpy") -> rwx.PyDiGraph:
|
|
||||||
if method == "polars":
|
def pmg_from_rcv(ballots: pl.DataFrame, method="bigslow") -> rwx.PyDiGraph:
|
||||||
return pmg_from_rcv_polars(ballots)
|
if method == "bigslow":
|
||||||
elif method == "numpy":
|
return pmg_from_rcv_bigslow(ballots)
|
||||||
return pmg_from_rcv_numpy(ballots)
|
elif method == "smallfast":
|
||||||
|
return pmg_from_rcv_smallfast(ballots)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"`pmg_from_rcv` method={method} not implemented.")
|
raise NotImplementedError(f"`pmg_from_rcv` method={method} not implemented.")
|
||||||
|
|
|
||||||
|
|
@ -15,11 +15,16 @@ from rich.table import Table
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|
||||||
import polars as pl
|
import polars as pl
|
||||||
from smithy import smith_set
|
from smithy import smith_set, irv_set
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.argument("ballots", type=click.Path(exists=True, dir_okay=False))
|
@click.argument("ballots", type=click.Path(exists=True, dir_okay=False))
|
||||||
|
@click.option(
|
||||||
|
"--try-resolve-irv",
|
||||||
|
is_flag=True,
|
||||||
|
help="Try to reduce or resolve the Smith set by running all-paths IRV on the set.",
|
||||||
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--show-ballots",
|
"--show-ballots",
|
||||||
"-b",
|
"-b",
|
||||||
|
|
@ -27,7 +32,7 @@ from smithy import smith_set
|
||||||
help="Show relevant ballots (after selections).",
|
help="Show relevant ballots (after selections).",
|
||||||
)
|
)
|
||||||
@click.option("--pretty", "-p", is_flag=True, help="Pretty-print output.")
|
@click.option("--pretty", "-p", is_flag=True, help="Pretty-print output.")
|
||||||
def cli(ballots: str, show_ballots=False, pretty=False) -> None:
|
def cli(ballots: str, try_resolve_irv=False, show_ballots=False, pretty=False) -> None:
|
||||||
"""
|
"""
|
||||||
Compute the Smith set from a box of ranked-choice ballots -- .csv or .xls(x).
|
Compute the Smith set from a box of ranked-choice ballots -- .csv or .xls(x).
|
||||||
|
|
||||||
|
|
@ -56,7 +61,7 @@ def cli(ballots: str, show_ballots=False, pretty=False) -> None:
|
||||||
df = df.with_columns(
|
df = df.with_columns(
|
||||||
[
|
[
|
||||||
pl.col(c)
|
pl.col(c)
|
||||||
.cast(pl.Utf8)
|
.cast(pl.String)
|
||||||
.str.strip_chars()
|
.str.strip_chars()
|
||||||
.cast(pl.Int64, strict=False)
|
.cast(pl.Int64, strict=False)
|
||||||
.fill_null(0)
|
.fill_null(0)
|
||||||
|
|
@ -66,6 +71,9 @@ def cli(ballots: str, show_ballots=False, pretty=False) -> None:
|
||||||
|
|
||||||
# Compute Smith set
|
# Compute Smith set
|
||||||
smiths = smith_set(df)
|
smiths = smith_set(df)
|
||||||
|
if len(smiths) > 1 and try_resolve_irv:
|
||||||
|
irv_ballots = df.select(smiths)
|
||||||
|
smiths = irv_set(irv_ballots)
|
||||||
|
|
||||||
if show_ballots and pretty:
|
if show_ballots and pretty:
|
||||||
preview = Table(title="Ballot Box")
|
preview = Table(title="Ballot Box")
|
||||||
|
|
@ -88,7 +96,9 @@ def cli(ballots: str, show_ballots=False, pretty=False) -> None:
|
||||||
console.print(
|
console.print(
|
||||||
Panel.fit(
|
Panel.fit(
|
||||||
"\n".join(f"• {c}" for c in smiths),
|
"\n".join(f"• {c}" for c in smiths),
|
||||||
title="Resulting Smith Set",
|
title="Resulting IRV-resolved Smith Set"
|
||||||
|
if (try_resolve_irv)
|
||||||
|
else "Resulting Smith Set",
|
||||||
border_style="green",
|
border_style="green",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue