diff --git a/internal/fp.lua b/internal/fp.lua
index 361659a..201c177 100644
--- a/internal/fp.lua
+++ b/internal/fp.lua
@@ -1,9 +1,125 @@
-local unpack = unpack or table.unpack
+--- Arithmetic on Curve25519's base field.
+--
+-- @module internal.fq
+--
+local unpack = unpack or table.unpack
+local bxor = bit32.bxor
+local band = bit32.band
+local bor = bit32.bor
+
+--- The modular square root of -1.
+local I = {
+ 0958640 * 2 ^ 0,
+ 0826664 * 2 ^ 22,
+ 1613251 * 2 ^ 43,
+ 1041528 * 2 ^ 64,
+ 0013673 * 2 ^ 85,
+ 0387171 * 2 ^ 107,
+ 1824679 * 2 ^ 128,
+ 0313839 * 2 ^ 149,
+ 0709440 * 2 ^ 170,
+ 0122635 * 2 ^ 192,
+ 0262782 * 2 ^ 213,
+ 0712905 * 2 ^ 234,
+}
+
+--- The difference between a non-canonical number and its canonical equivalent.
+local CDIFF = {
+ 2 ^ 22 - 19,
+ (2 ^ 21 - 1) * 2 ^ 22,
+ (2 ^ 21 - 1) * 2 ^ 43,
+ (2 ^ 21 - 1) * 2 ^ 64,
+ (2 ^ 22 - 1) * 2 ^ 85,
+ (2 ^ 21 - 1) * 2 ^ 107,
+ (2 ^ 21 - 1) * 2 ^ 128,
+ (2 ^ 21 - 1) * 2 ^ 149,
+ (2 ^ 22 - 1) * 2 ^ 170,
+ (2 ^ 21 - 1) * 2 ^ 192,
+ (2 ^ 21 - 1) * 2 ^ 213,
+ (2 ^ 21 - 1) * 2 ^ 234,
+}
+
+--- A base field polynomial.
+--
+-- The Curve25519 paper represents its numbers as "polynomals" that slice the
+-- bigint into a little-endian array of floats. Each float slice is such that
+-- the (infinite precision) sum of all of them is equal to the represented
+-- number.
+--
+-- For our implementation, we use an array of 12 floats. Each one has a specific
+-- exponent and mantissa range.
+--
+--
+--
+--
+-- Index | Coefficient Range | Multiplier |
+-- 0 | (-2²²..2²²) | 2⁰ |
+-- 1 | (-2²¹..2²¹) | 2²² |
+-- 2 | (-2²¹..2²¹) | 2⁴³ |
+-- 3 | (-2²¹..2²¹) | 2⁶⁴ |
+-- 4 | (-2²²..2²²) | 2⁸⁵ |
+-- 5 | (-2²¹..2²¹) | 2¹⁰⁷ |
+-- 6 | (-2²¹..2²¹) | 2¹²⁸ |
+-- 7 | (-2²¹..2²¹) | 2¹⁴⁹ |
+-- 8 | (-2²²..2²²) | 2¹⁷⁰ |
+-- 9 | (-2²¹..2²¹) | 2¹⁹² |
+-- 10 | (-2²¹..2²¹) | 2²¹³ |
+-- 11 | (-2²¹..2²¹) | 2²³⁴ |
+--
+--
+-- @type fq
+--
+local fq = nil
+if fq ~= nil then return end
+
+--- A nonnegative @{fq}.
+--
+-- This type represents elements that have no negative coefficients.
+--
+-- @type fqAbs
+--
+local fqAbs = nil
+if fqAbs ~= nil then return end
+
+--- An uncarried @{fq}.
+--
+-- This type represents elements that have coefficients in a wider range than
+-- the limits specified in @{fq}. Specifically, this represents all the results
+-- of uncarried float-wise additions of two elements.
+--
+-- @type fqUncarried
+--
+local fqUncarried = nil
+if fqUncarried ~= nil then return end
+
+--- Converts a Lua number to an element.
+--
+-- @tparam number n A number n in [0..2²²).
+-- @treturn fqAbs n as a base field element.
+--
local function num(n)
return {n, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
end
+--- Adds two elements.
+--
+-- @tparam fq a
+-- @tparam fq b
+-- @treturn fqUncarried
+--
local function add(a, b)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b)
@@ -23,6 +139,42 @@ local function add(a, b)
}
end
+--- Negates an element.
+--
+-- @tparam fq a
+-- @treturn fq
+--
+local function neg(a)
+ local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
+ return {
+ -a00,
+ -a01,
+ -a02,
+ -a03,
+ -a04,
+ -a05,
+ -a06,
+ -a07,
+ -a08,
+ -a09,
+ -a10,
+ -a11,
+ }
+end
+
+--- Subtracts an element from another.
+--
+-- If both elements are positive, then the result can be guaranteed to fit in
+-- a single @{fq} without needing any carrying.
+--
+-- @tparam[1] fq a
+-- @tparam[1] fq b
+-- @treturn[1] fqUncarried
+--
+-- @tparam[2] fqAbs a
+-- @tparam[2] fqAbs b
+-- @treturn[2] fq
+--
local function sub(a, b)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b)
@@ -42,50 +194,103 @@ local function sub(a, b)
}
end
-local function kmul(a, k)
+--- Carries an element.
+--
+-- @tparam fqUncarried a
+-- @treturn fqAbs
+--
+local function carry(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
- local c00, c01, c02, c03, c04, c05, c06, c07, c08, c09, c10, c11
+ local r00, r01, r02, r03, r04, r05, r06, r07, r08, r09, r10, r11
- -- Multiply.
- c00 = a00 * k
- c01 = a01 * k
- c02 = a02 * k
- c03 = a03 * k
- c04 = a04 * k
- c05 = a05 * k
- c06 = a06 * k
- c07 = a07 * k
- c08 = a08 * k
- c09 = a09 * k
- c10 = a10 * k
- c11 = a11 * k
+ r11 = a11 % 2 ^ 255
+ a00 = a00 + (a11 - r11) * (19 / 2 ^ 255)
- -- Carry and reduce.
- a11 = c11 % 2 ^ 255 c00 = c00 + (c11 - a11) * (19 / 2 ^ 255)
+ r00 = a00 % 2 ^ 22 a01 = a01 + (a00 - r00)
+ r01 = a01 % 2 ^ 43 a02 = a02 + (a01 - r01)
+ r02 = a02 % 2 ^ 64 a03 = a03 + (a02 - r02)
+ r03 = a03 % 2 ^ 85 a04 = a04 + (a03 - r03)
+ r04 = a04 % 2 ^ 107 a05 = a05 + (a04 - r04)
+ r05 = a05 % 2 ^ 128 a06 = a06 + (a05 - r05)
+ r06 = a06 % 2 ^ 149 a07 = a07 + (a06 - r06)
+ r07 = a07 % 2 ^ 170 a08 = a08 + (a07 - r07)
+ r08 = a08 % 2 ^ 192 a09 = a09 + (a08 - r08)
+ r09 = a09 % 2 ^ 213 a10 = a10 + (a09 - r09)
+ r10 = a10 % 2 ^ 234 a11 = r11 + (a10 - r10)
- a00 = c00 % 2 ^ 22 c01 = c01 + (c00 - a00)
- a01 = c01 % 2 ^ 43 c02 = c02 + (c01 - a01)
- a02 = c02 % 2 ^ 64 c03 = c03 + (c02 - a02)
- a03 = c03 % 2 ^ 85 c04 = c04 + (c03 - a03)
- a04 = c04 % 2 ^ 107 c05 = c05 + (c04 - a04)
- a05 = c05 % 2 ^ 128 c06 = c06 + (c05 - a05)
- a06 = c06 % 2 ^ 149 c07 = c07 + (c06 - a06)
- a07 = c07 % 2 ^ 170 c08 = c08 + (c07 - a07)
- a08 = c08 % 2 ^ 192 c09 = c09 + (c08 - a08)
- a09 = c09 % 2 ^ 213 c10 = c10 + (c09 - a09)
- a10 = c10 % 2 ^ 234 c11 = a11 + (c10 - a10)
+ r11 = a11 % 2 ^ 255 r00 = r00 + (a11 - r11) * (19 / 2 ^ 255)
- a11 = c11 % 2 ^ 255 a00 = a00 + (c11 - a11) * (19 / 2 ^ 255)
-
- return {a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11}
+ return {r00, r01, r02, r03, r04, r05, r06, r07, r08, r09, r10, r11}
end
+--- Returns whether the modp number is the canonical representative.
+--
+-- @see canonicalize
+--
+-- @tparam fqAbs a
+-- @treturn boolean
+--
+local function isCanonical(a)
+ local e11 = bxor(a[12] * 2 ^ -234, 2 ^ 21 - 1)
+ local e10 = bxor(a[11] * 2 ^ -213, 2 ^ 21 - 1)
+ local e09 = bxor(a[10] * 2 ^ -192, 2 ^ 21 - 1)
+ local e08 = bxor(a[09] * 2 ^ -170, 2 ^ 22 - 1)
+ local e07 = bxor(a[08] * 2 ^ -149, 2 ^ 21 - 1)
+ local e06 = bxor(a[07] * 2 ^ -128, 2 ^ 21 - 1)
+ local e05 = bxor(a[06] * 2 ^ -107, 2 ^ 21 - 1)
+ local e04 = bxor(a[05] * 2 ^ -85, 2 ^ 22 - 1)
+ local e03 = bxor(a[04] * 2 ^ -64, 2 ^ 21 - 1)
+ local e02 = bxor(a[03] * 2 ^ -43, 2 ^ 21 - 1)
+ local e01 = bxor(a[02] * 2 ^ -22, 2 ^ 21 - 1)
+ local e00 = band(a[01] - (2 ^ 22 - 19), 2 ^ 31)
+ return 0 ~= bor(e00, e01, e02, e03, e04, e05, e06, e07, e08, e09, e10, e11)
+end
+
+--- Returns the canoncal representative of a modp number.
+--
+-- Some elements can be represented by two different arrays of floats. This
+-- returns the canonical element of the represented equivalence class. We define
+-- an element as canonical if it's the smallest nonnegative number in its class.
+--
+-- @tparam fq a
+-- @treturn fqAbs
+--
+local function canonicalize(a)
+ a = carry(a)
+ local zero = num(0)
+ local diff = isCanonical(a) and zero or CDIFF
+ return sub(a, diff)
+end
+
+--- Returns whether two elements are the same.
+--
+-- @tparam fqAbs a
+-- @tparam fqAbs b
+-- @treturn boolean
+--
+local function eq(a, b)
+ a = canonicalize(a)
+ b = canonicalize(b)
+ for i = 1, 12 do
+ if a[i] ~= b[i] then
+ return false
+ end
+ end
+ return true
+end
+
+--- Multiplies two elements.
+--
+-- @tparam fqUncarried a
+-- @tparam fqUncarried b
+-- @treturn fqAbs
+--
local function mul(a, b)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b)
local c00, c01, c02, c03, c04, c05, c06, c07, c08, c09, c10, c11
- -- Multiply high half.
+ -- Multiply high half into c00..c11.
c00 = a11 * b01
+ a10 * b02
+ a09 * b03
@@ -153,7 +358,7 @@ local function mul(a, b)
+ a10 * b11
c10 = a11 * b11
- -- Multiply low half with reduction.
+ -- Multiply low half with reduction into c00..c11.
c00 = c00 * (19 / 2 ^ 255)
+ a00 * b00
c01 = c01 * (19 / 2 ^ 255)
@@ -265,6 +470,11 @@ local function mul(a, b)
return {a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11}
end
+--- Squares an element.
+--
+-- @tparam fqUncarried a
+-- @treturn fqAbs
+--
local function square(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
local d00, d01, d02, d03, d04, d05, d06, d07, d08, d09, d10
@@ -283,7 +493,7 @@ local function square(a)
d09 = a09 + a09
d10 = a10 + a10
- -- Multiply high half.
+ -- Multiply high half into c00..c11.
c00 = a11 * d01
+ a10 * d02
+ a09 * d03
@@ -321,7 +531,7 @@ local function square(a)
c09 = a11 * d10
c10 = a11 * a11
- -- Multiply low half with reduction.
+ -- Multiply low half with reduction into c00..c11.
c00 = c00 * (19 / 2 ^ 255)
+ a00 * a00
c01 = c01 * (19 / 2 ^ 255)
@@ -397,15 +607,56 @@ local function square(a)
return {a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11}
end
+--- Multiplies an element by a number.
+--
+-- @tparam fqUncarried a
+-- @tparam number k A number k in [0..2²¹).
+-- @treturn fqAbs
+--
+local function kmul(a, k)
+ local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
+
+ -- TODO WHY ARE TYPE CONSTRAINTS SO DIFFICULT TO SPECIFY
+ return carry {
+ a00 * k,
+ a01 * k,
+ a02 * k,
+ a03 * k,
+ a04 * k,
+ a05 * k,
+ a06 * k,
+ a07 * k,
+ a08 * k,
+ a09 * k,
+ a10 * k,
+ a11 * k,
+ }
+end
+
+--- Squares a modp number n times.
+--
+-- @tparam fqUncarried a
+-- @tparam number n
+-- @treturn fqAbs
+--
local function nsquare(a, n)
for _ = 1, n do a = square(a) end
return a
end
+--- Computes the inverse of an element.
+--
+-- Computation of the inverse requires 11 multiplicationss and 252 squarings.
+--
+-- @tparam fqUncarried a
+-- @treturn[1] fqAbs a⁻¹
+-- @treturn[2] fqAbs 0 if the argument is 0, which has no inverse.
+--
local function invert(a)
local a2 = square(a)
local a9 = mul(a, nsquare(a2, 2))
local a11 = mul(a9, a2)
+
local x5 = mul(square(a11), a9)
local x10 = mul(nsquare(x5, 5), x5)
local x20 = mul(nsquare(x10, 10), x10)
@@ -414,44 +665,71 @@ local function invert(a)
local x100 = mul(nsquare(x50, 50), x50)
local x200 = mul(nsquare(x100, 100), x100)
local x250 = mul(nsquare(x200, 50), x50)
+
return mul(nsquare(x250, 5), a11)
end
-local function encode(a)
- local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
+--- Returns an element x that satisfies v * x² = u.
+--
+-- Note that when v = 0, the returned value can take any @{fqAbs} value.
+--
+-- @tparam fqUncarried u
+-- @tparam fqUncarried v
+-- @treturn[1] fqAbs x
+-- @treturn[2] nil if there is no solution.
+--
+local function sqrtDiv(u, v)
+ u = carry(u)
- -- Canonicalize.
- if a11 == (2 ^ 21 - 1) * 2 ^ 234
- and a10 == (2 ^ 21 - 1) * 2 ^ 213
- and a09 == (2 ^ 21 - 1) * 2 ^ 192
- and a08 == (2 ^ 22 - 1) * 2 ^ 170
- and a07 == (2 ^ 21 - 1) * 2 ^ 149
- and a06 == (2 ^ 21 - 1) * 2 ^ 128
- and a05 == (2 ^ 21 - 1) * 2 ^ 107
- and a04 == (2 ^ 22 - 1) * 2 ^ 85
- and a03 == (2 ^ 21 - 1) * 2 ^ 64
- and a02 == (2 ^ 21 - 1) * 2 ^ 43
- and a01 == (2 ^ 21 - 1) * 2 ^ 22
- and a00 >= 2 ^ 22 - 19
- then
- a11 = 0
- a10 = 0
- a09 = 0
- a08 = 0
- a07 = 0
- a06 = 0
- a05 = 0
- a04 = 0
- a03 = 0
- a02 = 0
- a01 = 0
- a00 = a00 - (2 ^ 22 - 19)
+ local v2 = square(v)
+ local v3 = mul(v, v2)
+ local v6 = square(v3)
+ local v7 = mul(v, v6)
+ local uv7 = mul(u, v7)
+
+ local x2 = mul(square(uv7), uv7)
+ local x4 = mul(nsquare(x2, 2), x2)
+ local x8 = mul(nsquare(x4, 4), x4)
+ local x16 = mul(nsquare(x8, 8), x8)
+ local x18 = mul(nsquare(x16, 2), x2)
+ local x32 = mul(nsquare(x16, 16), x16)
+ local x50 = mul(nsquare(x32, 18), x18)
+ local x100 = mul(nsquare(x50, 50), x50)
+ local x200 = mul(nsquare(x100, 100), x100)
+ local x250 = mul(nsquare(x200, 50), x50)
+ local pr = mul(nsquare(x250, 2), uv7)
+
+ local uv3 = mul(u, v3)
+ local b = mul(uv3, pr)
+ local b2 = square(b)
+ local vb2 = mul(v, b2)
+
+ if not eq(vb2, u) then
+ -- Found sqrt(-u/v), multiply by i.
+ b = mul(b, I)
+ b2 = square(b)
+ vb2 = mul(v, b2)
end
- -- Encode.
- -- TODO this can be improved.
+ if eq(vb2, u) then
+ return b
+ else
+ return nil
+ end
+end
+
+--- Encodes an element in little-endian.
+--
+-- @tparam fqAbs a
+-- @treturn string A 32-byte string. Always represents the canonical element.
+--
+local function encode(a)
+ a = canonicalize(a)
+ local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
+
local bytes = {}
local acc = a00
+
local function putBytes(n)
for _ = 1, n do
local byte = acc % 256
@@ -476,13 +754,18 @@ local function encode(a)
return string.char(unpack(bytes))
end
+--- Decodes an element in little-endian.
+--
+-- @tparam string b A 32-byte string. The most-significant bit is discarded.
+-- @treturn fqAbs The decoded element. May not be canonical.
+--
local function decode(b)
local w00, w01, w02, w03, w04, w05, w06, w07, w08, w09, w10, w11 =
("