ccryptolib/internal/curve25519.lua
2022-04-08 11:56:03 -03:00

275 lines
8.1 KiB
Lua

--- Point arithmetic on the Curve25519 Montgomery curve.
--
-- :::note Internal Module
-- This module is meant for internal use within the library. Its API is unstable
-- and subject to change without major version bumps.
-- :::
--
-- <br />
--
-- @module[kind=internal] internal.curve25519
--
local fp = require "ccryptolib.internal.fp"
local ed = require "ccryptolib.internal.edwards25519"
local random = require "ccryptolib.random"
--- Doubles a point given in projective (X : Z) coordinates.
--
-- @tparam point P1 The point to double.
-- @treturn point The doubled point, not scaled back to Z = 1.
--
local function double(P1)
    local x, z = P1[1], P1[2]
    local sumSq = fp.square(fp.add(x, z))
    local diffSq = fp.square(fp.sub(x, z))
    -- (x + z)^2 - (x - z)^2
    local cross = fp.sub(sumSq, diffSq)
    return {
        fp.mul(sumSq, diffSq),
        fp.mul(cross, fp.add(diffSq, fp.kmul(cross, 121666))),
    }
end
--- Performs a differential addition in projective (X : Z) coordinates.
--
-- @tparam point DP The difference P1 - P2.
-- @tparam point P1 The first point.
-- @tparam point P2 The second point.
-- @treturn point P1 + P2, not scaled back to Z = 1.
--
local function dadd(DP, P1, P2)
    local sum1, diff1 = fp.add(P1[1], P1[2]), fp.sub(P1[1], P1[2])
    local sum2, diff2 = fp.add(P2[1], P2[2]), fp.sub(P2[1], P2[2])
    -- Cross products of one point's difference with the other's sum.
    local u = fp.mul(diff2, sum1)
    local v = fp.mul(sum2, diff1)
    return {
        fp.mul(DP[2], fp.square(fp.add(u, v))),
        fp.mul(DP[1], fp.square(fp.sub(u, v))),
    }
end
--- Performs a step on the Montgomery ladder.
--
-- Computes a doubling and a differential addition at once, sharing the sums
-- and differences of the coordinates of P1 between both operations.
--
-- @tparam point DP The difference P1 - P2.
-- @tparam point P1 The first point.
-- @tparam point P2 The second point.
-- @treturn point 2 * P1
-- @treturn point P1 + P2
--
local function step(DP, P1, P2)
    local x1, z1 = P1[1], P1[2]
    local x2, z2 = P2[1], P2[2]

    -- Sums and differences used by both halves of the step.
    local s1, d1 = fp.add(x1, z1), fp.sub(x1, z1)
    local s2, d2 = fp.add(x2, z2), fp.sub(x2, z2)

    -- Doubling of P1.
    local s1Sq, d1Sq = fp.square(s1), fp.square(d1)
    local e = fp.sub(s1Sq, d1Sq)
    local doubled = {
        fp.mul(s1Sq, d1Sq),
        fp.mul(e, fp.add(d1Sq, fp.kmul(e, 121666))),
    }

    -- Differential addition P1 + P2.
    local u, v = fp.mul(d2, s1), fp.mul(s2, d1)
    local added = {
        fp.mul(DP[2], fp.square(fp.add(u, v))),
        fp.mul(DP[1], fp.square(fp.sub(u, v))),
    }

    return doubled, added
end
--- Performs a scalar multiplication using the Montgomery ladder.
--
-- The accumulators start at (1 : 0) and DP, and every step is given DP as
-- the difference between them.
--
-- @tparam point DP The base point.
-- @tparam {number...} bits The scalar multiplier, in little-endian bits.
-- @treturn point The product, not scaled back to Z = 1.
--
local function ladder(DP, bits)
    local P, Q = {fp.num(1), fp.num(0)}, DP
    -- Walk from the last entry of the bit list down to the first.
    for i = #bits, 1, -1 do
        -- A nonzero bit doubles Q instead of P: swap, step, swap back.
        local swap = bits[i] ~= 0
        if swap then P, Q = Q, P end
        P, Q = step(DP, P, Q)
        if swap then P, Q = Q, P end
    end
    return P
end
--- Multiplies a point by a scalar, and then by 8.
--
-- @tparam point P The base point.
-- @tparam {number...} bits The scalar multiplier, in little-endian bits.
-- @treturn point The product, multiplied by 8.
--
local function ladder8(P, bits)
    -- Randomize: scale both coordinates by the same random field element.
    local blind = fp.decode(random.random(32))
    local blinded = {fp.mul(P[1], blind), fp.mul(P[2], blind)}

    -- Multiply by the scalar, then double three times for the factor of 8.
    local R = ladder(blinded, bits)
    for _ = 1, 3 do
        R = double(R)
    end
    return R
end
--- Scales a point so that its Z coordinate becomes 1.
--
-- @tparam point P The point to scale.
-- @treturn point The point {X / Z, 1}.
--
local function scale(P)
    local x = P[1]
    local zInv = fp.invert(P[2])
    return {fp.mul(x, zInv), fp.num(1)}
end
--- Encodes a point.
--
-- Only the first coordinate is encoded, so the point has to be scaled
-- beforehand.
--
-- @tparam point P The scaled point to encode.
-- @treturn string The 32-byte encoded point.
--
local function encode(P)
    local x = P[1]
    return fp.encode(x)
end
--- Decodes a point.
--
-- @tparam string str A 32-byte encoded point.
-- @treturn point The decoded point, with its Z coordinate set to 1.
--
local function decode(str)
    local x = fp.decode(str)
    return {x, fp.num(1)}
end
--- Performs a scalar multiplication by the base point G.
--
-- @tparam {number...} bits The scalar multiplier, in little-endian bits.
-- @return The product point.
--
local function mulG(bits)
    -- Do the multiplication by G on Edwards25519.
    local E = ed.mulG(bits)
    local y, z = E[2], E[3]

    -- Map the result to Curve25519 through the birational map, which sends
    -- the Edwards coordinates to (Y + Z : Z - Y).
    -- Never fails since G is in the large group, and the exponent is clamped.
    return {
        fp.carry(fp.add(y, z)),
        fp.carry(fp.sub(z, y)),
    }
end
--- Computes a twofold product from a ruleset.
--
-- Two cases return early with fewer values than the regular result:
--
-- - When [8]P has a zero X or Z coordinate (P has small order), that single
--   point is returned.
-- - When the list of rules is empty (m = n), a single nil is returned.
--
-- @tparam point P The base point.
-- @tparam {{number...}, {number...}} ruleset The ruleset generated by scalars
-- m, n. `ruleset[1]` is the bit list fed to the ladder, and `ruleset[2]` is
-- the list of rules, evaluated from its last entry down to its first.
-- @treturn point [8m]P
-- @treturn point [8n]P
-- @treturn point [8m]P - [8n]P
--
local function prac(P, ruleset)
    -- Randomize.
    local rf = fp.decode(random.random(32))
    local A = {fp.mul(P[1], rf), fp.mul(P[2], rf)}

    -- Start the base at [8]P.
    A = double(double(double(A)))

    -- Throw away small order points.
    if fp.eqz(fp.mul(A[1], A[2])) then return A end

    -- Now e = d = gcd(m, n) / 8.
    -- Update A from [8]P to [8u]P.
    A = ladder(A, ruleset[1])

    -- Reject rulesets where m = n.
    local rules = ruleset[2]
    if #rules == 0 then return nil end

    -- Evaluate the first rule.
    -- Since e = d, this means A - B = C = O. Differential addition fails when
    -- C = O, so we need to treat this case specially.
    -- Note that rules 0 and 1 never happen last, since the algorithm would stop
    -- one step earlier if they did.
    local B, C
    local first = rules[#rules]
    if first == 2 then
        -- (A, B, C) ← (2A + B, B, 2A) = (3A, A, 2A)
        local A2 = double(A)
        A, B, C = dadd(A, A2, A), A, A2
    elseif first == 3 or first == 5 then
        -- (A, B, C) ← (A + B, B, A) = (2A, A, A)
        -- or (A, B, C) ← (2A, B, 2A - B) = (2A, A, A)
        A, B, C = double(A), A, A
    elseif first == 6 then
        -- (A, B, C) ← (3A + 3B, B, 3A + 2B) = (6A, A, 5A)
        local A2 = double(A)
        local A3 = dadd(A, A2, A)
        A, B, C = double(A3), A, dadd(A, A3, A2)
    elseif first == 7 then
        -- (A, B, C) ← (3A + 2B, B, 3A + B) = (5A, A, 4A)
        local A2 = double(A)
        local A3 = dadd(A, A2, A)
        local A4 = double(A2)
        A, B, C = dadd(A3, A4, A), A, A4
    elseif first == 8 then
        -- (A, B, C) ← (3A + B, B, 3A) = (4A, A, 3A)
        local A2 = double(A)
        local A3 = dadd(A, A2, A)
        A, B, C = double(A2), A, A3
    else
        -- (A, B, C) ← (A, 2B, A - 2B) = (A, 2A, A)
        A, B, C = A, double(A), A
    end

    -- Evaluate the other rules.
    -- Let's assume addition is undefined here, this happens when A - B = 0.
    -- Since A = [d]P and B = [e]P, A = B happens when:
    -- (1) P is on the large order base group and d ≡ e (mod q).
    -- (2) P is on the large order twist group and d ≡ e (mod q').
    -- (3) P is on a small order group.
    -- Case (3) never happens since we throw small order points away above.
    -- Since 0 ≤ {d, e} < q < q', a modular equivalence here means an integer
    -- equivalence. Therefore d = e.
    -- However, the ruleset stops when d = e, therefore the algorithm must have
    -- stopped earlier than when it did. Contradiction.
    -- Therefore, addition is always defined.
    -- Furthermore, the PRAC invariants mean that this product is the same as
    -- if the points were multiplied separately.
    --
    -- In every dadd call below, the first argument is the known difference
    -- (or sum) of the other two, and the result is their sum (or difference).
    for i = #rules - 1, 1, -1 do
        local rule = rules[i]
        if rule == 0 then
            -- (A, B, C) ← (B, A, B - A)
            A, B = B, A
        elseif rule == 1 then
            -- (A, B, C) ← (2A + B, A + 2B, A - B)
            local AB = dadd(C, A, B)
            A, B = dadd(B, AB, A), dadd(A, AB, B)
        elseif rule == 2 then
            -- (A, B, C) ← (2A + B, B, 2A)
            A, C = dadd(B, dadd(C, A, B), A), double(A)
        elseif rule == 3 then
            -- (A, B, C) ← (A + B, B, A)
            A, C = dadd(C, A, B), A
        elseif rule == 5 then
            -- (A, B, C) ← (2A, B, 2A - B)
            A, C = double(A), dadd(B, A, C)
        elseif rule == 6 then
            -- (A, B, C) ← (3A + 3B, B, 3A + 2B)
            local AB = dadd(C, A, B)
            local AABB = double(AB)
            A, C = dadd(AB, AABB, AB), dadd(dadd(A, AB, B), AABB, A)
        elseif rule == 7 then
            -- (A, B, C) ← (3A + 2B, B, 3A + B)
            local AB = dadd(C, A, B)
            local AAB = dadd(B, AB, A)
            A, C = dadd(A, AAB, AB), dadd(AB, AAB, A)
        elseif rule == 8 then
            -- (A, B, C) ← (3A + B, B, 3A)
            local AA = double(A)
            A, C = dadd(C, AA, dadd(C, A, B)), dadd(A, AA, A)
        else
            -- (A, B, C) ← (A, 2B, A - 2B)
            B, C = double(B), dadd(A, C, B)
        end
    end

    return A, B, C
end
--- Multiplies the first coordinate of a point by an encoded field element.
--
-- @tparam point P The point to multiply.
-- @param m The encoded field element, as accepted by `fp.decode`.
-- @treturn point A new point with the multiplied first coordinate and the
-- second coordinate of P unchanged.
--
local function fieldMul(P, m)
    local factor = fp.decode(m)
    return {fp.mul(P[1], factor), P[2]}
end
-- Module exports.
return {
    -- Point built from the field elements 9 and 1.
    G = {fp.num(9), fp.num(1)},

    -- Point arithmetic and normalization.
    dadd = dadd,
    scale = scale,

    -- Conversion between points and 32-byte strings.
    encode = encode,
    decode = decode,

    -- Scalar multiplication.
    ladder8 = ladder8,
    mulG = mulG,
    prac = prac,

    -- Multiplication of a coordinate by a field element.
    fieldMul = fieldMul,
}