aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastiano Tronto <sebastiano@tronto.net>2024-12-21 12:30:41 +0100
committerSebastiano Tronto <sebastiano@tronto.net>2024-12-21 12:30:41 +0100
commit6eaa7b33aa8690f5ba4cee0897d2d05c71c27c20 (patch)
treeef686d3341f046ffe69945c3af40ec5dd4cad321
downloadzmodn-6eaa7b33aa8690f5ba4cee0897d2d05c71c27c20.tar.gz
zmodn-6eaa7b33aa8690f5ba4cee0897d2d05c71c27c20.zip
Initial commit
-rw-r--r--README.md10
-rwxr-xr-xtest128
-rw-r--r--zmodn.h71
3 files changed, 209 insertions, 0 deletions
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..900a296
--- /dev/null
+++ b/README.md
@@ -0,0 +1,10 @@
1# ZmodN - A simple library for modular arithmetic
2
3Usage:
4
51. `#include "zmodn.h"` in your project
62. enjoy
7
8# Development
9
10Run `chmod +x test` and then `./test` to run tests.
diff --git a/test b/test
new file mode 100755
index 0000000..2930e89
--- /dev/null
+++ b/test
@@ -0,0 +1,128 @@
1#if 0
2
3cc=${CC:-g++}
4bin="$(mktemp)"
5${cc} -x c++ -std=c++20 -o "$bin" "$(realpath $0)"
6"$bin"
7
8exit 0
9#endif
10
11#include "zmodn.h"
12#include <concepts>
13#include <functional>
14#include <iostream>
15#include <optional>
16
17template<typename S, typename T>
18requires std::convertible_to<S, T> || std::convertible_to<T, S>
19void assert_equal(S actual, T expected) {
20 if (actual != expected) {
21 std::cout << "Error!" << std::endl;
22 std::cout << "Expected: " << expected << std::endl;
23 std::cout << "But got: " << actual << std::endl;
24 exit(1);
25 }
26}
27
28class Test {
29public:
30 std::string name;
31 std::function<void()> f;
32} tests[] = {
33{
34 .name = "Constructor 2 mod 3",
35 .f = []() {
36 Zmod<3> two = Zmod<3>(2);
37 assert_equal(two.toint64(), INT64_C(2));
38 }
39},
40{
41 .name = "Constructor -7 mod 3",
42 .f = []() {
43 Zmod<3> z = -7;
44 assert_equal(z, Zmod<3>(2));
45 }
46},
47{
48 .name = "1+1 mod 2",
49 .f = []() {
50 auto oneplusone = Zmod<2>(1) + Zmod<2>(1);
51 assert_equal(oneplusone, Zmod<2>(0));
52 }
53},
54{
55 .name = "2 -= 5 (mod 4)",
56 .f = []() {
57 Zmod<4> z = 2;
58 auto diff = (z -= 5);
59 assert_equal(z, Zmod<4>(1));
60 assert_equal(diff, Zmod<4>(1));
61 }
62},
63{
64 .name = "Inverse of 0 mod 2",
65 .f = []() {
66 Zmod<2> z = 0;
67 auto inv = z.inverse();
68 assert_equal(inv.has_value(), false);
69 }
70},
71{
72 .name = "Inverse of 1 mod 2",
73 .f = []() {
74 Zmod<2> z = 1;
75 auto inv = z.inverse();
76 assert_equal(inv.has_value(), true);
77 assert_equal(inv.value(), Zmod<2>(1));
78 }
79},
80{
81 .name = "Inverse of 5 mod 7",
82 .f = []() {
83 Zmod<7> z = 5;
84 auto inv = z.inverse();
85 assert_equal(inv.has_value(), true);
86 assert_equal(inv.value(), Zmod<7>(3));
87 }
88},
89{
90 .name = "Inverse of 4 mod 12",
91 .f = []() {
92 Zmod<12> z = 4;
93 auto inv = z.inverse();
94 assert_equal(inv.has_value(), false);
95 }
96},
97{
98 .name = "4 / 7 (mod 12)",
99 .f = []() {
100 Zmod<12> n = 4;
101 Zmod<12> d = 7;
102 auto inv = n / d;
103 assert_equal(inv.has_value(), true);
104 assert_equal(inv.value(), Zmod<12>(4));
105 }
106},
107{
108 .name = "4 /= 7 (mod 12)",
109 .f = []() {
110 Zmod<12> n = 4;
111 Zmod<12> d = 7;
112 auto inv = (n /= d);
113 assert_equal(inv.has_value(), true);
114 assert_equal(inv.value(), Zmod<12>(4));
115 assert_equal(n, Zmod<12>(4));
116 }
117},
118};
119
120int main() {
121 for (auto t : tests) {
122 std::cout << t.name << ": ";
123 t.f();
124 std::cout << "OK" << std::endl;
125 }
126 std::cout << "All tests passed" << std::endl;
127 return 0;
128}
diff --git a/zmodn.h b/zmodn.h
new file mode 100644
index 0000000..8981257
--- /dev/null
+++ b/zmodn.h
@@ -0,0 +1,71 @@
1#ifndef ZMODN_H
2#define ZMODN_H
3
4#include <cstdint>
5#include <iostream>
6#include <optional>
7#include <tuple>
8
9std::tuple<int64_t, int64_t, int64_t> extended_gcd(int64_t, int64_t);
10
11template<int64_t N> requires(N > 1)
12class Zmod {
13public:
14 Zmod(int64_t z) : int64{(z%N + N) % N} {}
15 int64_t toint64() const { return int64; }
16
17 Zmod operator+(const Zmod& z) const { return int64 + z.int64; }
18 Zmod operator-(const Zmod& z) const { return int64 - z.int64; }
19 Zmod operator*(const Zmod& z) const { return int64 * z.int64; }
20 Zmod operator+=(const Zmod& z) { return int64 += z.int64; }
21 Zmod operator-=(const Zmod& z) { return int64 -= z.int64; }
22 Zmod operator*=(const Zmod& z) { return int64 *= z.int64; }
23
24 bool operator==(const Zmod& z) const { return int64 == z.int64; }
25 bool operator!=(const Zmod& z) const { return int64 != z.int64; }
26
27 std::optional<Zmod> inverse() const {
28 auto [g, a, _] = extended_gcd(int64, N);
29 return g == 1 ? Zmod(a) : std::optional<Zmod>{};
30 }
31
32 std::optional<Zmod> operator/(const Zmod& d) const {
33 auto i = d.inverse();
34 return i ? (*this) * i.value() : i;
35 }
36
37 std::optional<Zmod> operator/=(const Zmod& d) {
38 auto q = *this / d;
39 return q ? (*this = q.value()) : q;
40 }
41
42 friend std::ostream& operator<<(std::ostream& os, const Zmod<N>& z) {
43 return os << "(" << z.int64 << " mod " << N << ")";
44 }
45private:
46 int64_t int64;
47};
48
49void swapdiv(int64_t& oldx, int64_t& x, int64_t q) {
50 int64_t tmp = x;
51 x = oldx - q*tmp;
52 oldx = tmp;
53}
54
55std::tuple<int64_t, int64_t, int64_t> extended_gcd(int64_t a, int64_t b) {
56 int64_t oldr = a;
57 int64_t r = b;
58 int64_t olds = 1;
59 int64_t s = 0;
60 int64_t oldt = 0;
61 int64_t t = 1;
62 while (r != 0) {
63 auto q = oldr / r;
64 swapdiv(oldr, r, q);
65 swapdiv(olds, s, q);
66 swapdiv(oldt, t, q);
67 }
68 return {oldr, olds, oldt};
69}
70
71#endif