ccryptolib/internal/fq.lua
2022-04-08 11:56:03 -03:00

395 lines
11 KiB
Lua
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

--- Arithmetic on Curve25519's scalar field.
--
-- :::note Internal Module
-- This module is meant for internal use within the library. Its API is unstable
-- and subject to change without major version bumps.
-- :::
--
-- <br />
--
-- @module[kind=internal] internal.fq
--
local mp = require "ccryptolib.internal.mp"
local util = require "ccryptolib.internal.util"
local unpack = unpack or table.unpack
--- The scalar field's order, q = 2²⁵² + 27742317777372353535851937790883648493,
-- as 11 limbs in base 2²⁴, least significant limb first (the top limb, 4096,
-- contributes 2¹² × 2²⁴⁰ = 2²⁵²).
local Q = {
16110573,
06494812,
14047250,
10680220,
14612958,
00000020,
00000000,
00000000,
00000000,
00000000,
00004096,
}
--- The first Montgomery precomputed constant, -q⁻¹ mod 2²⁶⁴.
-- Used by REDC (mul, demontgomery) to compute the multiple of q that zeroes
-- the low 264 bits of a double-width product.
local T0 = {
05537307,
01942290,
16765621,
16628356,
10618610,
07072433,
03735459,
01369940,
15276086,
13038191,
13409718,
}
--- The second Montgomery precomputed constant, 2⁵²⁸ mod q.
-- mul(a, T1) = 2⁻²⁶⁴ × a × 2⁵²⁸ = 2²⁶⁴ × a (mod q), i.e. conversion into
-- Montgomery form.
local T1 = {
11711996,
01747860,
08326961,
03814718,
01859974,
13327461,
16105061,
07590423,
04050668,
08138906,
00000283,
}
--- The third precomputed constant, 2⁵²⁸ × 8⁻¹ mod q.
-- mul(a, T8) = 2²⁶⁴ × a ÷ 8 (mod q): Montgomery conversion fused with a
-- modular division by 8, as required by decodeClamped8's contract.
local T8 = {
01130678,
05563041,
03870191,
01622646,
01247520,
12151703,
16693196,
09337410,
04700637,
07308819,
00002083,
}
--- The number zero as a multiprecision table; demontgomery uses it as the
-- (empty) high half of its double-width input.
local ZERO = mp.num(0)
--- Reduces a number modulo q.
--
-- @tparam {number...} a A number 0 ≤ a < 2q as 11 limbs in [0..2²⁵). Limbs may
-- also be negative (neg passes q - x); mp.carry normalizes signed limbs.
-- @treturn {number...} a mod q as exactly 11 limbs in [0..2²⁴).
--
local function reduce(a)
    -- c = a - q, limb-wise; only its sign is inspected below.
    local c = mp.sub(a, Q)

    local r
    if mp.approx(c) < 0 then
        -- a < q: a is already reduced; only its limbs need normalizing.
        r = mp.carry(a)
    else
        -- a ≥ q: since a < 2q we have 0 ≤ c < q, so c is the reduced value.
        -- Its limbs lie in (-2²⁴..2²⁵), which carry normalizes.
        r = mp.carry(c)
    end

    -- The result is < q < 2²⁶⁴ in both branches, so any limb carry produces
    -- beyond the 11th is zero. Drop it in BOTH branches so callers always get
    -- exactly 11 limbs; otherwise the limb count (and, presumably, the length
    -- of anything derived from it such as bits()) would depend on the value.
    r[12] = nil
    return r
end
--- Computes the sum of two scalars modulo q.
--
-- Works unchanged on Montgomery-form operands: because
-- (2²⁶⁴ × a) + (2²⁶⁴ × b) ≡ 2²⁶⁴ × (a + b) (mod q), the sum of two
-- Montgomery representatives is the Montgomery representative of the sum.
--
-- @tparam {number...} a A number a < q as 11 limbs in [0..2²⁴).
-- @tparam {number...} b A number b < q as 11 limbs in [0..2²⁴).
-- @treturn {number...} a + b mod q as 11 limbs in [0..2²⁴).
--
local function add(a, b)
    -- a + b < 2q with limbs in [0..2²⁵), which is exactly reduce's domain.
    local total = mp.add(a, b)
    return reduce(total)
end
--- Computes the additive inverse of a scalar modulo q.
--
-- @tparam {number...} a A number a < q as 11 limbs in [0..2²⁴).
-- @treturn {number...} -a mod q as 11 limbs in [0..2²⁴).
--
local function neg(a)
    -- q - a lies in (0..q]; for a = 0 it is q itself, which reduce maps to 0.
    local difference = mp.sub(Q, a)
    return reduce(difference)
end
--- Computes the difference of two scalars modulo q.
--
-- Works unchanged on Montgomery-form operands: because
-- (2²⁶⁴ × a) - (2²⁶⁴ × b) ≡ 2²⁶⁴ × (a - b) (mod q), the difference of two
-- Montgomery representatives is the Montgomery representative of a - b.
--
-- @tparam {number...} a A number a < q as 11 limbs in [0..2²⁴).
-- @tparam {number...} b A number b < q as 11 limbs in [0..2²⁴).
-- @treturn {number...} a - b mod q as 11 limbs in [0..2²⁴).
--
local function sub(a, b)
    -- a - b ≡ a + (-b) (mod q); both steps keep values below q.
    local negatedB = neg(b)
    return add(a, negatedB)
end
--- Montgomery-multiplies two scalars: returns 2⁻²⁶⁴ × a × b mod q.
--
-- This is Montgomery's REDC with R = 2²⁶⁴. When both inputs are in Montgomery
-- form, the output is the Montgomery form of their product.
--
-- @tparam {number...} a A number a as 11 limbs in [0..2²⁴).
-- @tparam {number...} b A number b < q as 11 limbs in [0..2²⁴).
-- @treturn 2⁻²⁶⁴ × a × b mod q as 11 limbs in [0..2²⁴).
--
local function mul(a, b)
    -- Double-width product a × b, returned as a low and a high half.
    local prodLo, prodHi = mp.mul(a, b)

    -- m = (a × b) × (-q⁻¹) mod 2²⁶⁴ (lmul presumably keeps only the low half),
    -- then m × q as another double-width value.
    local m = mp.lmul(prodLo, T0)
    local mqLo, mqHi = mp.mul(m, Q)

    -- a × b + m × q ≡ 0 (mod 2²⁶⁴), so its high half equals
    -- (a × b + m × q) / 2²⁶⁴, which is < 2q given a < 2²⁶⁴ and b < q.
    local _, quotient = mp.dwadd(prodLo, prodHi, mqLo, mqHi)
    return reduce(quotient)
end
--- Maps a scalar into Montgomery form.
--
-- @tparam {number...} a A number a as 11 limbs in [0..2²⁴).
-- @treturn {number...} 2²⁶⁴ × a mod q as 11 limbs in [0..2²⁴).
--
local function montgomery(a)
    -- mul computes 2⁻²⁶⁴ × a × T1 with T1 = 2⁵²⁸ mod q, giving 2²⁶⁴ × a.
    -- mul's preconditions hold: a < 2²⁶⁴ and T1 < q.
    local converted = mul(a, T1)
    return converted
end
--- Maps a scalar out of Montgomery form.
--
-- @tparam {number...} a A number a < q as 11 limbs in [0..2²⁴).
-- @treturn {number...} 2⁻²⁶⁴ × a mod q as 11 limbs in [0..2²⁴).
--
local function demontgomery(a)
    -- This is REDC applied to a × 1: a itself is the low half of the
    -- double-width input and ZERO is its high half.
    local m = mp.lmul(a, T0)
    local mqLo, mqHi = mp.mul(m, Q)
    local _, quotient = mp.dwadd(a, ZERO, mqLo, mqHi)
    return reduce(quotient)
end
--- Lifts a small Lua number into a Montgomery-form scalar.
--
-- @tparam number n A number n in [0..2²⁴).
-- @treturn {number...} 2²⁶⁴ × n mod q as 11 limbs in [0..2²⁴).
--
local function num(n)
    -- n occupies the lowest limb; the remaining ten limbs are zero.
    local limbs = {n, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
    return montgomery(limbs)
end
--- Serializes a Montgomery-form scalar to 32 little-endian bytes.
--
-- @tparam {number...} a A number 2²⁶⁴ × a mod q as 11 limbs in [0..2²⁴).
-- @treturn string The 32-byte string encoding of a.
--
local function encode(a)
    local plain = demontgomery(a)
    -- Ten 3-byte limbs plus a 2-byte top limb = 32 bytes. The top limb fits in
    -- 2 bytes because the value is below q < 2²⁵³.
    return string.pack("<I3I3I3I3I3I3I3I3I3I3I2", unpack(plain))
end
--- Parses 32 little-endian bytes into a Montgomery-form scalar.
--
-- @tparam string str A 32-byte string encoding some little-endian number a.
-- @treturn {number...} 2²⁶⁴ × a mod q as 11 limbs in [0..2²⁴).
--
local function decode(str)
    local limbs = {string.unpack("<I3I3I3I3I3I3I3I3I3I3I2", str)}
    -- string.unpack also returns the position after the last byte read;
    -- discard it so only the 11 limbs remain.
    limbs[12] = nil
    return montgomery(limbs)
end
--- Parses 64 little-endian bytes into a Montgomery-form scalar, reducing the
-- full 512-bit value modulo q.
--
-- @tparam string str A 64-byte string encoding some little-endian number a.
-- @treturn {number...} 2²⁶⁴ × a mod q as 11 limbs in [0..2²⁴).
--
local function decodeWide(str)
    -- Bytes 1..33 hold the low 264 bits, L (eleven 3-byte limbs).
    local lowLimbs = {string.unpack("<I3I3I3I3I3I3I3I3I3I3I3", str)}
    lowLimbs[12] = nil
    -- Bytes 34..64 hold the remaining 248 bits, H (ten 3-byte limbs + 1 byte).
    local highLimbs = {string.unpack("<I3I3I3I3I3I3I3I3I3I3I1", str, 34)}
    highLimbs[12] = nil

    -- a = L + 2²⁶⁴ × H, hence 2²⁶⁴ × a ≡ 2²⁶⁴ × L + 2²⁶⁴ × (2²⁶⁴ × H) (mod q):
    -- one Montgomery conversion for L, two for H.
    local lowPart = montgomery(lowLimbs)
    local highPart = montgomery(montgomery(highLimbs))
    return add(lowPart, highPart)
end
--- Decodes 32 little-endian bytes into limbs and applies X25519/Ed25519
-- clamping: clears bits 0..2 and bit 255, and sets bit 254.
--
-- Shared by decodeClamped and decodeClamped8, which differ only in how the
-- clamped value is converted afterwards.
--
-- @tparam string str A 32-byte string encoding some little-endian number a.
-- @treturn {number...} clamp(a) as 11 limbs in [0..2²⁴), NOT in Montgomery
-- form.
--
local function decodeClampedLimbs(str)
    local words = {("<I3I3I3I3I3I3I3I3I3I3I2"):unpack(str)}
    -- Drop the next-read position that unpack returns after the 11 limbs.
    words[12] = nil
    -- Limb 1 holds bits 0..23: clear bits 0..2.
    words[1] = bit32.band(words[1], 0xfffff8)
    -- Limb 11 holds bits 240..255: clear bit 255 (0x8000), set bit 254 (0x4000).
    words[11] = bit32.bor(bit32.band(words[11], 0x7fff), 0x4000)
    return words
end

--- Decodes a scalar using the X25519/Ed25519 bit clamping scheme.
--
-- @tparam string str A 32-byte string encoding some little-endian number a.
-- @treturn {number...} 2²⁶⁴ × clamp(a) mod q as 11 limbs in [0..2²⁴).
--
local function decodeClamped(str)
    return montgomery(decodeClampedLimbs(str))
end

--- Decodes a scalar using the X25519/Ed25519 bit clamping scheme and division
-- by 8.
--
-- @tparam string str A 32-byte string encoding some little-endian number a.
-- @treturn {number...} 2²⁶⁴ × clamp(a) ÷ 8 mod q as 11 limbs in [0..2²⁴).
--
local function decodeClamped8(str)
    -- With T8 = 2⁵²⁸ × 8⁻¹ mod q, mul gives 2⁻²⁶⁴ × x × T8 = 2²⁶⁴ × x ÷ 8,
    -- fusing the Montgomery conversion with the division by 8.
    return mul(decodeClampedLimbs(str), T8)
end
--- Expands a Montgomery-form scalar into its binary digits.
--
-- @tparam {number...} a A number a < q as 11 limbs in [0..2²⁴).
-- @treturn {number...} 2⁻²⁶⁴ × a mod q as 265 bits.
--
local function bits(a)
    -- Leave Montgomery form, then re-express the base-2²⁴ limbs in base 2.
    local plain = demontgomery(a)
    local digits = util.rebaseLE(plain, 2 ^ 24, 2)
    return digits
end
--- Makes a PRAC ruleset from a pair of scalars.
--
-- Both scalars are taken out of Montgomery form, giving a pair (d, e). The
-- pair is then reduced step by step with the PRAC rules M0..M9 (after
-- Montgomery's Lucas-chain algorithm) until d = e. Each step's rule is picked
-- from two things: the relative sizes of e and f = d - e (compared through
-- floating-point approximations), and the residues of d and e modulo 2 and 3.
-- The index of each applied rule is appended to the rule list. Rule M4 does
-- the same update as M2 and is recorded as 2, so recorded indices are drawn
-- from {0, 1, 2, 3, 5, 6, 7, 8, 9}.
--
-- @tparam {number...} a A scalar a < q as 11 limbs in [0..2²⁴).
-- @tparam {number...} b A scalar b < q as 11 limbs in [0..2²⁴).
-- @treturn {{number...}, {number...}} The generated ruleset {ubits, rules}:
-- ubits holds the binary digits of the final common value d (= e), as
-- produced by util.rebaseLE, with trailing zero entries removed (presumably
-- least significant first, given the trimming of trailing zeros). rules holds
-- the applied rule indices, in order.
--
local function makeRuleset(a, b)
-- The numbers in raw multiprecision tables.
-- Loop invariants: ft = dt - et; ff and ef approximate ft and et as floats;
-- d2, e2 are d, e mod 2; d3, e3 are d, e mod 3.
local dt = demontgomery(a) -- (-2²⁴..2²⁴)
local et = demontgomery(b) -- (-2²⁴..2²⁴)
local ft = mp.sub(dt, et) -- (-2²⁵..2²⁵)
-- Residue classes of (d, e) modulo 2.
local d2 = mp.mod2(dt)
local e2 = mp.mod2(et)
-- Residue classes of (d, e) modulo 3.
local d3 = mp.mod3(dt)
local e3 = mp.mod3(et)
-- (e, d - e) in limited-precision floating-point numbers.
local ef = mp.approx(et)
local ff = mp.approx(ft)
-- Lookup table for inversions and halvings modulo 3.
-- lut3[x] = -x mod 3. Since 2⁻¹ ≡ 2 ≡ -1 (mod 3), it also gives x / 2 mod 3.
local lut3 = {[0] = 0, 2, 1}
local rules = {}
-- Iterate until d = e, i.e. until f (as approximated by ff) is zero.
while ff ~= 0 do
if ff < 0 then
-- M0.
-- d < e: swap so that d ≥ e holds for the remaining rules.
rules[#rules + 1] = 0
-- (d, e) ← (e, d)
dt, et = et, dt
d2, e2 = e2, d2
d3, e3 = e3, d3
ef = mp.approx(et)
ft = mp.sub(dt, et)
ff = -ff
elseif 4 * ff < ef and d3 == lut3[e3] then
-- M1.
-- Condition: d - e < e/4 (i.e. d < 5e/4) and d ≡ -e (mod 3).
rules[#rules + 1] = 1
-- (d, e) ← ((2d - e)/3, (2e - d)/3)
-- 2d - e = d + f and 2e - d = e - f.
dt, et = mp.third(mp.add(dt, ft)), mp.third(mp.sub(et, ft))
-- Dividing by 3 preserves parity, and 2d - e ≡ e, 2e - d ≡ d (mod 2).
d2, e2 = e2, d2
d3, e3 = mp.mod3(dt), mp.mod3(et)
ef = mp.approx(et)
-- The new d - e equals the old d - e, so ft and ff need no update.
elseif 4 * ff < ef and d2 == e2 and d3 == e3 then
-- M2.
-- Condition: d < 5e/4 and d ≡ e (mod 6).
rules[#rules + 1] = 2
-- (d, e) ← ((d - e)/2, e)
dt = mp.half(ft)
d2 = mp.mod2(dt)
d3 = lut3[(d3 - e3) % 3]
ft = mp.sub(dt, et)
ff = mp.approx(ft)
elseif ff < 3 * ef then
-- M3.
-- Condition: d - e < 3e, i.e. d < 4e.
rules[#rules + 1] = 3
-- (d, e) ← (d - e, e)
dt = mp.carryWeak(ft)
d2 = (d2 - e2) % 2
d3 = (d3 - e3) % 3
ft = mp.sub(dt, et)
ff = mp.approx(ft)
elseif d2 == e2 then
-- M4 (same as M2).
-- Condition: d ≡ e (mod 2); recorded with index 2.
rules[#rules + 1] = 2
-- (d, e) ← ((d - e)/2, e)
dt = mp.half(ft)
d2 = mp.mod2(dt)
d3 = lut3[(d3 - e3) % 3]
ft = mp.sub(dt, et)
ff = mp.approx(ft)
elseif d2 == 0 then
-- M5.
-- Condition: d even.
rules[#rules + 1] = 5
-- (d, e) ← (d/2, e)
dt = mp.half(dt)
d2 = mp.mod2(dt)
d3 = lut3[d3]
ft = mp.sub(dt, et)
ff = mp.approx(ft)
elseif d3 == 0 then
-- M6.
-- Condition: d ≡ 0 (mod 3).
rules[#rules + 1] = 6
-- (d, e) ← (d/3 - e, e)
dt = mp.carryWeak(mp.sub(mp.third(dt), et))
-- d/3 ≡ d (mod 2), so d/3 - e ≡ d - e (mod 2).
d2 = (d2 - e2) % 2
d3 = mp.mod3(dt)
ft = mp.sub(dt, et)
ff = mp.approx(ft)
elseif d3 == lut3[e3] then
-- M7.
-- Condition: d ≡ -e (mod 3).
rules[#rules + 1] = 7
-- (d, e) ← ((d - 2e)/3, e)
-- d - 2e = f - e. (d - 2e)/3 ≡ d (mod 2), so d2 is unchanged.
dt = mp.third(mp.sub(ft, et))
d3 = mp.mod3(dt)
ft = mp.sub(dt, et)
ff = mp.approx(ft)
elseif d3 == e3 then
-- M8.
-- Condition: d ≡ e (mod 3).
rules[#rules + 1] = 8
-- (d, e) ← ((d - e)/3, e)
dt = mp.third(ft)
d2 = (d2 - e2) % 2
d3 = mp.mod3(dt)
ft = mp.sub(dt, et)
ff = mp.approx(ft)
else
-- M9.
-- Reached only when d is odd (M5 failed) and d ≢ e (mod 2) (M4 failed),
-- so e is even.
rules[#rules + 1] = 9
-- (d, e) ← (d, e/2)
et = mp.half(et)
e2 = mp.mod2(et)
e3 = lut3[e3]
ef = mp.approx(et)
ft = mp.sub(dt, et)
ff = mp.approx(ft)
end
end
-- d = e here; emit d's binary digits and trim trailing zero entries.
local ubits = util.rebaseLE(dt, 2 ^ 24, 2)
while ubits[#ubits] == 0 do ubits[#ubits] = nil end
return {ubits, rules}
end
return {
    -- Construction and Montgomery-form conversion.
    num = num,
    montgomery = montgomery,
    demontgomery = demontgomery,

    -- Field arithmetic.
    add = add,
    neg = neg,
    sub = sub,
    mul = mul,

    -- Byte-string encoding and decoding.
    encode = encode,
    decode = decode,
    decodeWide = decodeWide,
    decodeClamped = decodeClamped,
    decodeClamped8 = decodeClamped8,

    -- Binary expansion and PRAC ruleset generation.
    bits = bits,
    makeRuleset = makeRuleset,
}