mirror of
https://github.com/tgorordo/smithy.git
synced 2026-06-05 16:22:15 -07:00
update to SCC algorithm, tidy internal names
This commit is contained in:
parent
fcf2505820
commit
674fdc1fe9
7 changed files with 42 additions and 45 deletions
3
justfile
3
justfile
|
|
@ -7,9 +7,6 @@ run *args:
|
|||
marimo:
|
||||
uv run marimo --edit
|
||||
|
||||
example:
|
||||
uv run python src/main.py test/test_ballot.csv
|
||||
|
||||
format:
|
||||
uv run ruff format src test
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ authors = [
|
|||
]
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"numpy>=2.4.6",
|
||||
"polars>=1.40.1",
|
||||
"rustworkx>=0.17.1",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -96,9 +96,7 @@ nbformat==5.10.4
|
|||
nodeenv==1.10.0
|
||||
# via pyright
|
||||
numpy==2.4.6
|
||||
# via
|
||||
# smithy (pyproject.toml)
|
||||
# rustworkx
|
||||
# via rustworkx
|
||||
openai==2.37.0
|
||||
# via pydantic-ai-slim
|
||||
opentelemetry-api==1.42.0
|
||||
|
|
|
|||
|
|
@ -1,16 +1,17 @@
|
|||
import polars as pl
|
||||
import rustworkx as rwx
|
||||
from itertools import combinations
|
||||
|
||||
from .rcv import pairmaj_from_rcv
|
||||
from .rcv import pmg_from_rcv
|
||||
|
||||
|
||||
def smith_set_brutefrom_pairmaj(pairmaj_graph: dict[str, set[str]]) -> list:
|
||||
def ss_from_pmg(pmg: rwx.PyDiGraph) -> list[str]:
|
||||
"""
|
||||
Brute-force the Smith set from a pairwise majority winner graph.
|
||||
Find the Smith set from a pairwise majority graph.
|
||||
|
||||
parameters
|
||||
---
|
||||
pairmaj_graph: dict[str, set[str]]
|
||||
pmg: rwx.PyDiGraph
|
||||
A graph whose nodes correspond to candidates and (directed) edges show
|
||||
which candidates they beat pairwise.
|
||||
|
||||
|
|
@ -22,28 +23,18 @@ def smith_set_brutefrom_pairmaj(pairmaj_graph: dict[str, set[str]]) -> list:
|
|||
(single Majority winner), the Smith set will contain that single candidate.
|
||||
"""
|
||||
|
||||
candidates = set(pairmaj_graph.keys())
|
||||
size = len(candidates)
|
||||
sccs = rwx.strongly_connected_components(pmg)
|
||||
|
||||
for size in range(1, len(candidates) + 1):
|
||||
for sub in combinations(candidates, size):
|
||||
subset = set(sub)
|
||||
out = set(candidates) - subset
|
||||
cg = rwx.condensation(pmg, sccs)
|
||||
|
||||
dom = True
|
||||
src_sccs = [nd for nd in cg.node_indices() if cg.in_degree(nd) == 0]
|
||||
|
||||
for member in subset:
|
||||
if not out.issubset(pairmaj_graph[member]):
|
||||
dom = False
|
||||
break
|
||||
smith_set = sorted([c for scc in src_sccs for c in cg[scc]])
|
||||
|
||||
if dom:
|
||||
return sorted(subset)
|
||||
|
||||
return []
|
||||
return smith_set
|
||||
|
||||
|
||||
def smith_set_from_rcv(rcv_ballots: pl.DataFrame) -> list:
|
||||
def smith_set_from_rcv(ballots: pl.DataFrame) -> list:
|
||||
"""
|
||||
Compute the Smith set from a Ranked-Choice ballot.
|
||||
|
||||
|
|
@ -52,7 +43,7 @@ def smith_set_from_rcv(rcv_ballots: pl.DataFrame) -> list:
|
|||
|
||||
parameters
|
||||
---
|
||||
df : pl.DataFrame
|
||||
ballots : pl.DataFrame
|
||||
A Polars DataFrame representing ballots. Each column is a candidate and each
|
||||
row is is a voter's ranking of the candidates. Lower numbers indicate higher
|
||||
preference (1 = top-choice).
|
||||
|
|
@ -66,7 +57,8 @@ def smith_set_from_rcv(rcv_ballots: pl.DataFrame) -> list:
|
|||
|
||||
"""
|
||||
|
||||
return smith_set_brutefrom_pairmaj(pairmaj_from_rcv(rcv_ballots))
|
||||
# return smith_set_brutefrom_pairmaj(pairmaj_from_rcv(rcv_ballots))
|
||||
return ss_from_pmg(pmg_from_rcv(ballots))
|
||||
|
||||
|
||||
def smith_set(df: pl.DataFrame, ballotkind="rcv") -> list:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import polars as pl
|
||||
import rustworkx as rwx
|
||||
from itertools import combinations
|
||||
|
||||
|
||||
def pairmaj_from_rcv(rcv_ballots: pl.DataFrame) -> dict[str, set[str]]:
|
||||
def pmg_from_rcv(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
||||
"""
|
||||
Build a pairwise majority winner graph from a box of Ranked-Choice Ballots.
|
||||
|
||||
|
|
@ -15,27 +16,38 @@ def pairmaj_from_rcv(rcv_ballots: pl.DataFrame) -> dict[str, set[str]]:
|
|||
|
||||
returns
|
||||
---
|
||||
pairmaj_graph: dict[str, set[str]]
|
||||
nodes: dict[str, int]
|
||||
A dictionary of candidate names to associated node ids.
|
||||
|
||||
pwm_graph: rwx.PyDiGraph
|
||||
A pairwise majority winner graph whose nodes correspond to candidates and
|
||||
(directed) edges show which candidates they beat pairwise.
|
||||
"""
|
||||
candidates = rcv_ballots.columns
|
||||
candidates = ballots.columns
|
||||
|
||||
pairmaj_graph: dict[str, set[str]] = {c: set() for c in candidates}
|
||||
pmg = rwx.PyDiGraph()
|
||||
nodes = {c: pmg.add_node(c) for c in candidates}
|
||||
|
||||
for a, b in combinations(candidates, 2):
|
||||
result = rcv_ballots.select(
|
||||
exprs = []
|
||||
pairs = list(combinations(candidates, 2))
|
||||
|
||||
for a, b in pairs:
|
||||
exprs.extend(
|
||||
[
|
||||
(pl.col(a) < pl.col(b)).sum().alias("a_wins"),
|
||||
(pl.col(b) < pl.col(a)).sum().alias("b_wins"),
|
||||
(pl.col(a) < pl.col(b)).sum().alias(f"{a}>{b}"),
|
||||
(pl.col(b) < pl.col(a)).sum().alias(f"{b}>{a}"),
|
||||
]
|
||||
).row(0)
|
||||
)
|
||||
|
||||
a_wins, b_wins = result
|
||||
results = ballots.select(exprs).row(0, named=True)
|
||||
|
||||
for a, b in pairs:
|
||||
a_wins = results[f"{a}>{b}"]
|
||||
b_wins = results[f"{b}>{a}"]
|
||||
|
||||
if a_wins > b_wins:
|
||||
pairmaj_graph[a].add(b)
|
||||
pmg.add_edge(nodes[a], nodes[b], a_wins - b_wins)
|
||||
elif b_wins > a_wins:
|
||||
pairmaj_graph[b].add(a)
|
||||
pmg.add_edge(nodes[b], nodes[a], b_wins - a_wins)
|
||||
|
||||
return pairmaj_graph
|
||||
return pmg
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@
|
|||
# dependencies = [
|
||||
# "click>=8.4.1",
|
||||
# "rich>=15.0.0",
|
||||
# "polars>=1.40.1"
|
||||
# "polars>=1.40.1",
|
||||
# "rustworkx>=0.17.1"
|
||||
# ]
|
||||
# ///
|
||||
import sys, io
|
||||
|
|
|
|||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -1556,7 +1556,6 @@ name = "smithy"
|
|||
version = "0.1.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "numpy" },
|
||||
{ name = "polars" },
|
||||
{ name = "rustworkx" },
|
||||
]
|
||||
|
|
@ -1584,7 +1583,6 @@ dev = [
|
|||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "click", marker = "extra == 'cli'", specifier = ">=8.4.0" },
|
||||
{ name = "numpy", specifier = ">=2.4.6" },
|
||||
{ name = "polars", specifier = ">=1.40.1" },
|
||||
{ name = "rich", marker = "extra == 'cli'", specifier = ">=15.0.0" },
|
||||
{ name = "rustworkx", specifier = ">=0.17.1" },
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue