# Mirror of https://github.com/Paillat-dev/esc-dramatic-unpause.git
# (synced 2026-01-02 01:06:21 +00:00; 186 lines, 5.7 KiB, Python)
# Copyright (c) Paillat-dev
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import TypedDict
|
|
|
|
import orjson
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), "..")) # noqa: PTH120, PTH118
|
|
|
|
import tempfile
|
|
|
|
import pytest
|
|
from typer.testing import CliRunner
|
|
|
|
from src.__main__ import app
|
|
|
|
# Root of the dataset files, resolved relative to this module's directory.
DATASET_BASE = Path(__file__).parent.joinpath("dataset", "data")
# Sub-directory of the dataset from which the per-year folders are read.
DATASET_PATH = DATASET_BASE.joinpath("senior")

# Contest years exercised by the tests: 2016 through 2025 inclusive.
YEARS_TO_TEST = [*range(2016, 2026)]

# Rest of world televote was introduced in Eurovision 2023
REST_OF_WORLD_VOTE_YEAR = 2023
|
|
|
|
|
|
class ESCData(TypedDict):
    """Parsed results of one contest year, as consumed by the tests in this module."""

    # Points per country name taken from the scores named "jury".
    jury: dict[str, int]
    # Points per country name taken from the scores named "public".
    televote: dict[str, int]
    # Name of the country placed first ("Unknown" when none was found).
    winner: str
    # Number of contestant directories found for the year.
    participating_countries: int
|
|
|
|
# Load the country-code -> country-name lookup table once, at import time.
with (DATASET_BASE / "countries.json").open("rb") as f:
    COUNTRY_NAMES: dict[str, str] = orjson.loads(f.read())
|
|
|
|
def get_country_name(country_code: str) -> str:
    """Return the full country name for a country code.

    The lookup is done on the upper-cased code; when the code is not present in
    ``COUNTRY_NAMES`` the code is returned exactly as it was passed in.
    """
    lookup_key = country_code.upper()
    return COUNTRY_NAMES.get(lookup_key, country_code)
|
|
|
|
|
|
def get_country_mapping(year: int) -> dict[int, str]:
    """
    Build the contestant-ID -> country-code lookup for one year.

    Scans the sub-directories of ``<DATASET_PATH>/<year>/contestants``. A
    directory named ``N_XX`` contributes ``{N: "XX"}``, with the country code
    upper-cased. Entries that are not directories, have no underscore in their
    name, or whose prefix is not an integer are ignored.

    Returns a dict: {contestant_id: country_code}; empty when the contestants
    directory does not exist.
    """
    root = DATASET_PATH / str(year) / "contestants"
    mapping: dict[int, str] = {}

    if not root.exists():
        return mapping

    for entry in root.iterdir():
        if not entry.is_dir():
            continue

        # Split on the first underscore only: "<contestant id>_<country code>".
        id_part, separator, code = entry.name.partition("_")
        if not separator:
            continue

        try:
            contestant_id = int(id_part)
        except ValueError:
            # Prefix is not a contestant number; skip this directory.
            continue

        mapping[contestant_id] = code.upper()

    return mapping
|
|
|
|
|
|
def count_participating_countries(year: int) -> int:
    """Return how many sub-directories the year's contestants folder contains.

    Yields 0 when ``<DATASET_PATH>/<year>/contestants`` does not exist.
    """
    folder = DATASET_PATH / str(year) / "contestants"
    if folder.exists():
        return len([entry for entry in folder.iterdir() if entry.is_dir()])
    return 0
|
|
|
|
|
|
def parse_year_data(year: int) -> ESCData:
    """Parse the data for a single year from the dataset.

    Reads ``<DATASET_PATH>/<year>/rounds/final.json`` and, for every performance
    whose contestant ID maps to a known country, records the points of the
    scores named ``"jury"`` and ``"public"``. The country whose performance has
    ``place == 1`` is reported as the winner.

    Args:
        year: Contest year whose dataset directory should be read.

    Returns:
        An ``ESCData`` mapping holding the jury and televote scores (each sorted
        by points, highest first), the winner's name (``"Unknown"`` when no
        performance is marked as first place) and the number of contestant
        directories found for the year.

    Raises:
        FileNotFoundError: If the final results file for ``year`` does not exist.
    """
    final_json_path = DATASET_PATH / str(year) / "rounds" / "final.json"

    # This function is called while the module is imported, i.e. outside of any
    # test. ``pytest.skip`` raises a ``BaseException`` subclass there, which an
    # ``except Exception`` handler does not catch and which makes pytest abort
    # collection of the whole module. Raise a regular exception instead so the
    # caller can report the missing year and carry on.
    if not final_json_path.exists():
        msg = f"Data not found for year {year}"
        raise FileNotFoundError(msg)

    # Load the final.json data
    data = orjson.loads(final_json_path.read_bytes())

    country_mapping = get_country_mapping(year)

    # Count participating countries
    participating_countries = count_participating_countries(year)

    jury_scores: dict[str, int] = {}
    televote_scores: dict[str, int] = {}
    winner = None

    for performance in data.get("performances", []):
        country_code = country_mapping.get(performance.get("contestantId"))
        if not country_code:
            # No contestant directory matches this performance; ignore it.
            continue

        country_name = get_country_name(country_code)

        if performance.get("place") == 1:
            winner = country_name

        # Extract scores
        for score in performance.get("scores", []):
            score_name = score.get("name")
            points = score.get("points", 0)

            if score_name == "jury":
                jury_scores[country_name] = points
            elif score_name == "public":
                televote_scores[country_name] = points

    # Sort by points (descending); the sort is stable, so ties keep file order.
    jury_scores = dict(sorted(jury_scores.items(), key=lambda x: x[1], reverse=True))
    televote_scores = dict(sorted(televote_scores.items(), key=lambda x: x[1], reverse=True))

    return {
        "jury": jury_scores,
        "televote": televote_scores,
        "winner": winner or "Unknown",
        "participating_countries": participating_countries,
    }
|
|
|
|
|
|
# Parse every configured year up front so the results can drive the test
# parametrization; a year that raises while parsing is reported and left out.
ESC_DATA: dict[int, ESCData] = {}
for year in YEARS_TO_TEST:
    try:
        ESC_DATA[year] = parse_year_data(year)
    except Exception as exc:
        print(f"Warning: Could not parse data for year {year}: {exc}")  # noqa: T201


# Marker searched for in the CLI output; the winner's name is read right after it.
TADA = "🎉"
|
|
|
|
|
|
@pytest.mark.parametrize(("year", "data"), ESC_DATA.items())
def test_esc_grand_final(year: int, data: ESCData) -> None:
    """Test the ESC grand final for a given year.

    Writes the jury scores to a temporary file, runs the CLI against it while
    feeding the televote points on stdin, and checks that the word printed right
    after the ``TADA`` marker equals the expected winner.

    Args:
        year: Contest year under test; decides whether the rest-of-world flag is set.
        data: Parsed scores, winner and participant count for that year.
    """
    jury_scores: dict[str, int] = data["jury"]
    televote_scores: dict[str, int] = data["televote"]
    expected_winner: str = data["winner"]
    participating_countries: int = data["participating_countries"]

    # Televote points are supplied in the reverse of the jury file's order.
    # NOTE(review): presumably this matches the order the CLI prompts in — confirm
    # against src.__main__.
    inputs: list[str] = [str(televote_scores[country]) for country in reversed(jury_scores)]
    inputs.append("y")  # to confirm the winner

    # Determine if rest of world vote should be included based on the year
    rest_of_world_vote = year >= REST_OF_WORLD_VOTE_YEAR

    # A temporary directory removes the jury file once the CLI has run, instead
    # of leaving one stray file behind per parametrized case.
    with tempfile.TemporaryDirectory() as tmp_dir:
        jury_path = Path(tmp_dir) / "jury.txt"
        jury_path.write_text(
            "".join(f"{country} {score}\n" for country, score in jury_scores.items()),
            encoding="utf-8",
        )

        runner = CliRunner()
        result = runner.invoke(
            app,
            [
                "--jury-path", str(jury_path),
                "--participating-countries", str(participating_countries),
                "--rest-of-world-vote" if rest_of_world_vote else "--no-rest-of-world-vote",
            ],
            input="\n".join(inputs),
        )

    try:
        # First whitespace-delimited word following the marker.
        actual = result.output.split(TADA)[1].strip().split()[0]
    except IndexError:
        # Marker missing, or nothing printed after it.
        pytest.fail(f"Could not parse winner from output:\n{result.output}", pytrace=False)

    assert actual == expected_winner, f"For {year}, expected winner {expected_winner} but got {actual!r}"