39 lines
1.2 KiB
Python
39 lines
1.2 KiB
Python
from __future__ import annotations
|
|
|
|
import unittest
|
|
|
|
import numpy as np
|
|
|
|
from float_index_sort import argsort_float32_bucket, argsort_float32_radix
|
|
|
|
|
|
class SortingAlgorithmTests(unittest.TestCase):
    """Check each float32 argsort implementation against NumPy's stable argsort."""

    def setUp(self) -> None:
        # Label -> sorter under test; the label shows up in subTest output so a
        # failure names the offending implementation.
        self.algorithms = {
            "radix": argsort_float32_radix,
            "bucket": argsort_float32_bucket,
        }

    def assert_matches_numpy(self, values: np.ndarray) -> None:
        """Assert every registered sorter returns NumPy's stable argsort of *values*.

        The reference is ``np.argsort(values, kind="stable")``, so equal keys
        (including ``-0.0`` and ``0.0``, which compare equal) must keep their
        original relative order for a sorter to pass.

        Args:
            values: Array handed unchanged to each sorter.
        """
        expected = np.argsort(values, kind="stable")
        for name, sorter in self.algorithms.items():
            with self.subTest(algorithm=name):
                actual = sorter(values)
                # Unlike assertTrue(np.array_equal(...)), this reports the
                # mismatching positions / shapes when the permutations differ.
                np.testing.assert_array_equal(actual, expected)

    def test_small_mixed_values(self) -> None:
        """Mixed signs, duplicates, signed zeros and +inf in a tiny array."""
        values = np.array([3.0, -2.5, 1.0, -0.0, 0.0, 1.0, -5.0, np.inf], dtype=np.float32)
        self.assert_matches_numpy(values)

    def test_descending_values(self) -> None:
        """513 evenly spaced values from 9.0 down to -9.0 (reverse-ordered input)."""
        values = np.linspace(9.0, -9.0, 513, dtype=np.float32)
        self.assert_matches_numpy(values)

    def test_random_values(self) -> None:
        """4096 normally distributed values from a fixed seed, so runs are reproducible."""
        rng = np.random.default_rng(42)
        values = rng.normal(loc=0.0, scale=4.0, size=4096).astype(np.float32)
        self.assert_matches_numpy(values)
|
|
|
|
|
|
# Allow running this file directly (``python <file>``) as well as via a test runner.
if __name__ == "__main__":
    unittest.main()