aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastiano Tronto <sebastiano@tronto.net>2024-12-21 17:22:09 +0100
committerSebastiano Tronto <sebastiano@tronto.net>2024-12-21 17:22:09 +0100
commit7e705a3342bc36f9f02b4e8e0050bc5842a62f9b (patch)
treee496ef743d63168bc1f7a5d57167a0229889b396
parentb0d8a2f5f5718ba3fea628cad7bcbbc76c6a140c (diff)
downloadzmodn-7e705a3342bc36f9f02b4e8e0050bc5842a62f9b.tar.gz
zmodn-7e705a3342bc36f9f02b4e8e0050bc5842a62f9b.zip
Make N's type generic for possible use with bigint
-rwxr-xr-xtest6
-rw-r--r--zmodn.h44
2 files changed, 26 insertions, 24 deletions
diff --git a/test b/test
index 3256a12..b36c238 100755
--- a/test
+++ b/test
@@ -34,7 +34,7 @@ public:
34 .name = "Constructor 2 mod 3", 34 .name = "Constructor 2 mod 3",
35 .f = []() { 35 .f = []() {
36 Zmod<3> two = Zmod<3>(2); 36 Zmod<3> two = Zmod<3>(2);
37 assert_equal(two.toint64(), INT64_C(2)); 37 assert_equal(two.toint(), INT64_C(2));
38 } 38 }
39}, 39},
40{ 40{
@@ -121,7 +121,7 @@ public:
121 Zmod<10> n = 8; 121 Zmod<10> n = 8;
122 Zmod<10> m = 9; 122 Zmod<10> m = 9;
123 auto prod = m * n; 123 auto prod = m * n;
124 assert_equal(prod.toint64(), 2); 124 assert_equal(prod.toint(), 2);
125 } 125 }
126}, 126},
127{ 127{
@@ -130,7 +130,7 @@ public:
130 Zmod<10> n = 8; 130 Zmod<10> n = 8;
131 Zmod<10> m = 9; 131 Zmod<10> m = 9;
132 n *= m; 132 n *= m;
133 assert_equal(n.toint64(), 2); 133 assert_equal(n.toint(), 2);
134 } 134 }
135}, 135},
136}; 136};
diff --git a/zmodn.h b/zmodn.h
index f161ad1..b3b128e 100644
--- a/zmodn.h
+++ b/zmodn.h
@@ -5,27 +5,35 @@
5#include <iostream> 5#include <iostream>
6#include <optional> 6#include <optional>
7#include <tuple> 7#include <tuple>
8#include <type_traits>
8 9
9std::tuple<int64_t, int64_t, int64_t> extended_gcd(int64_t, int64_t); 10template<typename INT>
11requires std::is_integral_v<INT>
12std::tuple<INT, INT, INT> extended_gcd(INT a, INT b) {
13 if (b == 0) return {a, 1, 0};
14 auto [g, x, y] = extended_gcd(b, a%b);
15 return {g, y, x - y*(a/b)};
16}
10 17
11template<int64_t N> requires(N > 1) 18template<auto N>
19requires(N > 1) && std::is_integral_v<decltype(N)>
12class Zmod { 20class Zmod {
13public: 21public:
14 Zmod(int64_t z) : int64{(z%N + N) % N} {} 22 Zmod(decltype(N) z) : value{(z%N + N) % N} {}
15 int64_t toint64() const { return int64; } 23 decltype(N) toint() const { return value; }
16 24
17 Zmod operator+(const Zmod& z) const { return int64 + z.int64; } 25 Zmod operator+(const Zmod& z) const { return value + z.value; }
18 Zmod operator-(const Zmod& z) const { return int64 - z.int64; } 26 Zmod operator-(const Zmod& z) const { return value - z.value; }
19 Zmod operator*(const Zmod& z) const { return int64 * z.int64; } 27 Zmod operator*(const Zmod& z) const { return value * z.value; }
20 Zmod operator+=(const Zmod& z) { return (*this) = int64 + z.int64; } 28 Zmod operator+=(const Zmod& z) { return (*this) = value + z.value; }
21 Zmod operator-=(const Zmod& z) { return (*this) = int64 - z.int64; } 29 Zmod operator-=(const Zmod& z) { return (*this) = value - z.value; }
22 Zmod operator*=(const Zmod& z) { return (*this) = int64 * z.int64; } 30 Zmod operator*=(const Zmod& z) { return (*this) = value * z.value; }
23 31
24 bool operator==(const Zmod& z) const { return int64 == z.int64; } 32 bool operator==(const Zmod& z) const { return value == z.value; }
25 bool operator!=(const Zmod& z) const { return int64 != z.int64; } 33 bool operator!=(const Zmod& z) const { return value != z.value; }
26 34
27 std::optional<Zmod> inverse() const { 35 std::optional<Zmod> inverse() const {
28 auto [g, a, _] = extended_gcd(int64, N); 36 auto [g, a, _] = extended_gcd(value, N);
29 return g == 1 ? Zmod(a) : std::optional<Zmod>{}; 37 return g == 1 ? Zmod(a) : std::optional<Zmod>{};
30 } 38 }
31 39
@@ -40,16 +48,10 @@ public:
40 } 48 }
41 49
42 friend std::ostream& operator<<(std::ostream& os, const Zmod<N>& z) { 50 friend std::ostream& operator<<(std::ostream& os, const Zmod<N>& z) {
43 return os << "(" << z.int64 << " mod " << N << ")"; 51 return os << "(" << z.value << " mod " << N << ")";
44 } 52 }
45private: 53private:
46 int64_t int64; 54 decltype(N) value;
47}; 55};
48 56
49std::tuple<int64_t, int64_t, int64_t> extended_gcd(int64_t a, int64_t b) {
50 if (b == 0) return {a, 1, 0};
51 auto [g, x, y] = extended_gcd(b, a%b);
52 return {g, y, x - y*(a/b)};
53}
54
55#endif 57#endif