diff --git a/ed25519.lua b/ed25519.lua
index bb067d2..1385611 100644
--- a/ed25519.lua
+++ b/ed25519.lua
@@ -4,231 +4,11 @@
--
local expect = require "cc.expect".expect
-local fp = require "ccryptolib.internal.fp"
local fq = require "ccryptolib.internal.fq"
local sha512 = require "ccryptolib.internal.sha512"
+local ed = require "ccryptolib.internal.edwards25519"
local random = require "ccryptolib.random"
-local unpack = unpack or table.unpack
-
-local D = fp.mul(fp.num(-121665), fp.invert(fp.num(121666)))
-local K = fp.kmul(D, 2)
-
-local O = {fp.num(0), fp.num(1), fp.num(1), fp.num(0)}
-local G = nil
-
-local function double(P1)
- -- Unsoundness: fp.sub(g, e), and fp.sub(d, i) break fp.sub's contract since
- -- it doesn't accept an fp2. Although not ideal, in practice this doesn't
- -- matter since fp.carry handles the larger sum.
- local P1x, P1y, P1z = unpack(P1)
- local a = fp.square(P1x)
- local b = fp.square(P1y)
- local c = fp.square(P1z)
- local d = fp.add(c, c)
- local e = fp.add(a, b)
- local f = fp.add(P1x, P1y)
- local g = fp.square(f)
- local h = fp.carry(fp.sub(g, e))
- local i = fp.sub(b, a)
- local j = fp.carry(fp.sub(d, i))
- local P3x = fp.mul(h, j)
- local P3y = fp.mul(i, e)
- local P3z = fp.mul(j, i)
- local P3t = fp.mul(h, e)
- return {P3x, P3y, P3z, P3t}
-end
-
-local function add(P1, N1)
- local P1x, P1y, P1z, P1t = unpack(P1)
- local N1p, N1m, N1z, N1t = unpack(N1)
- local a = fp.sub(P1y, P1x)
- local b = fp.mul(a, N1m)
- local c = fp.add(P1y, P1x)
- local d = fp.mul(c, N1p)
- local e = fp.mul(P1t, N1t)
- local f = fp.mul(P1z, N1z)
- local g = fp.sub(d, b)
- local h = fp.sub(f, e)
- local i = fp.add(f, e)
- local j = fp.add(d, b)
- local P3x = fp.mul(g, h)
- local P3y = fp.mul(i, j)
- local P3z = fp.mul(h, i)
- local P3t = fp.mul(g, j)
- return {P3x, P3y, P3z, P3t}
-end
-
-local function sub(P1, N1)
- local P1x, P1y, P1z, P1t = unpack(P1)
- local N1p, N1m, N1z, N1t = unpack(N1)
- local a = fp.sub(P1y, P1x)
- local b = fp.mul(a, N1p)
- local c = fp.add(P1y, P1x)
- local d = fp.mul(c, N1m)
- local e = fp.mul(P1t, N1t)
- local f = fp.mul(P1z, N1z)
- local g = fp.sub(d, b)
- local h = fp.add(f, e)
- local i = fp.sub(f, e)
- local j = fp.add(d, b)
- local P3x = fp.mul(g, h)
- local P3y = fp.mul(i, j)
- local P3z = fp.mul(h, i)
- local P3t = fp.mul(g, j)
- return {P3x, P3y, P3z, P3t}
-end
-
-local function niels(P1)
- local P1x, P1y, P1z, P1t = unpack(P1)
- local N3p = fp.add(P1y, P1x)
- local N3m = fp.sub(P1y, P1x)
- local N3z = fp.add(P1z, P1z)
- local N3t = fp.mul(P1t, K)
- return {N3p, N3m, N3z, N3t}
-end
-
-local function scale(P1)
- local P1x, P1y, P1z = unpack(P1)
- local zInv = fp.invert(P1z)
- local P3x = fp.mul(P1x, zInv)
- local P3y = fp.mul(P1y, zInv)
- local P3z = fp.num(1)
- local P3t = fp.mul(P3x, P3y)
- return {P3x, P3y, P3z, P3t}
-end
-
-local function encode(P1)
- local P1x, P1y = unpack(P1)
- local y = fp.encode(P1y)
- local xBit = fp.canonicalize(P1x)[1] % 2
- return y:sub(1, -2) .. string.char(y:byte(-1) + xBit * 128)
-end
-
-local function decode(str)
- local P3y = fp.decode(str)
- local a = fp.square(P3y)
- local b = fp.sub(a, fp.num(1))
- local c = fp.mul(a, D)
- local d = fp.add(c, fp.num(1))
- local P3x = fp.sqrtDiv(b, d)
- if not P3x then return nil end
- local xBit = fp.canonicalize(P3x)[1] % 2
- if xBit ~= bit32.extract(str:byte(-1), 7) then
- P3x = fp.carry(fp.sub(fp.P, P3x))
- end
- local P3z = fp.num(1)
- local P3t = fp.mul(P3x, P3y)
- return {P3x, P3y, P3z, P3t}
-end
-
-G = decode("Xfffffffffffffffffffffffffffffff")
-
-local function signedRadixW(bits, w)
- -- TODO Find a more elegant way of doing this.
- local wPow = 2 ^ w
- local wPowh = wPow / 2
- local out = {}
- local acc = 0
- local mul = 1
- for i = 1, #bits do
- acc = acc + bits[i] * mul
- mul = mul * 2
- while i == #bits and acc > 0 or mul > wPow do
- local rem = acc % wPow
- if rem >= wPowh then rem = rem - wPow end
- acc = (acc - rem) / wPow
- mul = mul / wPow
- out[#out + 1] = rem
- end
- end
- return out
-end
-
-local function radixWTable(P, w)
- local out = {}
- for i = 1, 255 / w do
- local row = {niels(P)}
- for j = 2, 2 ^ w / 2 do
- P = add(P, row[1])
- row[j] = niels(P)
- end
- out[i] = row
- P = double(P)
- end
- return out
-end
-
-local G_W = 5
-local G_TABLE = radixWTable(G, G_W)
-
-local function WNAF(bits, w)
- -- TODO Find a more elegant way of doing this.
- local wPow = 2 ^ w
- local wPowh = wPow / 2
- local out = {}
- local acc = 0
- local mul = 1
- for i = 1, #bits do
- acc = acc + bits[i] * mul
- mul = mul * 2
- while i == #bits and acc > 0 or mul > wPow do
- if acc % 2 == 0 then
- acc = acc / 2
- mul = mul / 2
- out[#out + 1] = 0
- else
- local rem = acc % wPow
- if rem >= wPowh then rem = rem - wPow end
- acc = acc - rem
- out[#out + 1] = rem
- end
- end
- end
- while out[#out] == 0 do out[#out] = nil end
- return out
-end
-
-local function WNAFTable(P, w)
- local dP = double(P)
- local out = {niels(P)}
- for i = 3, 2 ^ w, 2 do
- out[i] = niels(add(dP, out[i - 2]))
- end
- return out
-end
-
-local function mulG(bits)
- local sw = signedRadixW(bits, G_W)
- local R = O
- for i = 1, #sw do
- local b = sw[i]
- if b > 0 then
- R = add(R, G_TABLE[i][b])
- elseif b < 0 then
- R = sub(R, G_TABLE[i][-b])
- end
- end
- return R
-end
-
-local function mul(P, bits)
- local naf = WNAF(bits, 5)
- local tbl = WNAFTable(P, 5)
- local R = O
- for i = #naf, 1, -1 do
- local b = naf[i]
- if b == 0 then
- R = double(R)
- elseif b > 0 then
- R = add(R, tbl[b])
- else
- R = sub(R, tbl[-b])
- end
- end
- return R
-end
-
local mod = {}
--- Computes a public key from a secret key.
@@ -243,7 +23,7 @@ function mod.publicKey(sk)
local h = sha512.digest(sk)
local x = fq.decodeClamped(h:sub(1, 32))
- return encode(scale(mulG(fq.bits(x))))
+ return ed.encode(ed.scale(ed.mulG(fq.bits(x))))
end
--- Signs a message.
@@ -266,8 +46,8 @@ function mod.sign(sk, pk, msg)
-- Commitment.
local k = fq.decodeWide(random.random(64))
- local r = mulG(fq.bits(k))
- local rStr = encode(scale(r))
+ local r = ed.mulG(fq.bits(k))
+ local rStr = ed.encode(ed.scale(r))
-- Challenge.
local e = fq.decodeWide(sha512.digest(rStr .. pk .. msg))
@@ -294,7 +74,7 @@ function mod.verify(pk, msg, sig)
expect(3, sig, "string")
assert(#sig == 64, "signature length must be 64")
- local y = decode(pk)
+ local y = ed.decode(pk)
if not y then return nil end
local rStr = sig:sub(1, 32)
@@ -302,11 +82,11 @@ function mod.verify(pk, msg, sig)
local e = fq.decodeWide(sha512.digest(rStr .. pk .. msg))
- local gs = mulG(fq.bits(fq.decode(sStr)))
- local ye = mul(y, fq.bits(e))
- local rv = add(gs, niels(ye))
+ local gs = ed.mulG(fq.bits(fq.decode(sStr)))
+ local ye = ed.mul(y, fq.bits(e))
+ local rv = ed.add(gs, ed.niels(ye))
- return encode(scale(rv)) == rStr
+ return ed.encode(ed.scale(rv)) == rStr
end
return mod
diff --git a/internal/curve25519.lua b/internal/curve25519.lua
new file mode 100644
index 0000000..83b5999
--- /dev/null
+++ b/internal/curve25519.lua
@@ -0,0 +1,95 @@
+--- Point arithmetic on the Curve25519 Montgomery curve.
+--
+-- :::note Internal Module
+-- This module is meant for internal use within the library. Its API is unstable
+-- and subject to change without major version bumps.
+-- :::
+--
+--
+--
+-- @module[kind=internal] internal.curve25519
+--
+
+local fp = require "ccryptolib.internal.fp"
+local random = require "ccryptolib.random"
+
+local unpack = unpack or table.unpack
+
+local function double(x1, z1)
+ local a = fp.add(x1, z1)
+ local aa = fp.square(a)
+ local b = fp.sub(x1, z1)
+ local bb = fp.square(b)
+ local c = fp.sub(aa, bb)
+ local x3 = fp.mul(aa, bb)
+ local z3 = fp.mul(c, fp.add(bb, fp.kmul(c, 121666)))
+ return x3, z3
+end
+
+local function step(dxmul, dx, x1, z1, x2, z2)
+ local a = fp.add(x1, z1)
+ local aa = fp.square(a)
+ local b = fp.sub(x1, z1)
+ local bb = fp.square(b)
+ local e = fp.sub(aa, bb)
+ local c = fp.add(x2, z2)
+ local d = fp.sub(x2, z2)
+ local da = fp.mul(d, a)
+ local cb = fp.mul(c, b)
+ local x4 = fp.square(fp.add(da, cb))
+ local z4 = dxmul(fp.square(fp.sub(da, cb)), dx)
+ local x3 = fp.mul(aa, bb)
+ local z3 = fp.mul(e, fp.add(bb, fp.kmul(e, 121666)))
+ return x3, z3, x4, z4
+end
+
+local function bits(str)
+ -- Decode.
+ local bytes = {str:byte(1, 32)}
+ local out = {}
+ for i = 1, 32 do
+ local byte = bytes[i]
+ for j = -7, 0 do
+ local bit = byte % 2
+ out[8 * i + j] = bit
+ byte = (byte - bit) / 2
+ end
+ end
+
+ -- Clamp.
+ out[256] = 0
+ out[255] = 1
+
+ -- We remove the 3 lowest bits since the ladder already multiplies by 8.
+ return {unpack(out, 4)}
+end
+
+local function ladder8(dxmul, dx, bits)
+ local x1 = fp.num(1)
+ local z1 = fp.num(0)
+
+ local z2 = fp.decode(random.random(32))
+ local x2 = dxmul(z2, dx)
+
+ -- Standard ladder.
+ for i = #bits, 1, -1 do
+ if bits[i] == 0 then
+ x1, z1, x2, z2 = step(dxmul, dx, x1, z1, x2, z2)
+ else
+ x2, z2, x1, z1 = step(dxmul, dx, x2, z2, x1, z1)
+ end
+ end
+
+ -- Multiply by 8 (double 3 times).
+ for _ = 1, 3 do
+ x1, z1 = double(x1, z1)
+ end
+
+ return fp.mul(x1, fp.invert(z1))
+end
+
+return {
+ double = double,
+ bits = bits,
+ ladder8 = ladder8,
+}
diff --git a/internal/edwards25519.lua b/internal/edwards25519.lua
new file mode 100644
index 0000000..370ae0b
--- /dev/null
+++ b/internal/edwards25519.lua
@@ -0,0 +1,244 @@
+--- Point arithmetic on the Edwards25519 Edwards curve.
+--
+-- :::note Internal Module
+-- This module is meant for internal use within the library. Its API is unstable
+-- and subject to change without major version bumps.
+-- :::
+--
+--
+--
+-- @module[kind=internal] internal.edwards25519
+--
+
+local fp = require "ccryptolib.internal.fp"
+
+local unpack = unpack or table.unpack
+
+local D = fp.mul(fp.num(-121665), fp.invert(fp.num(121666)))
+local K = fp.kmul(D, 2)
+
+local O = {fp.num(0), fp.num(1), fp.num(1), fp.num(0)}
+local G = nil
+
+local function double(P1)
+ -- Unsoundness: fp.sub(g, e), and fp.sub(d, i) break fp.sub's contract since
+ -- it doesn't accept an fp2. Although not ideal, in practice this doesn't
+ -- matter since fp.carry handles the larger sum.
+ local P1x, P1y, P1z = unpack(P1)
+ local a = fp.square(P1x)
+ local b = fp.square(P1y)
+ local c = fp.square(P1z)
+ local d = fp.add(c, c)
+ local e = fp.add(a, b)
+ local f = fp.add(P1x, P1y)
+ local g = fp.square(f)
+ local h = fp.carry(fp.sub(g, e))
+ local i = fp.sub(b, a)
+ local j = fp.carry(fp.sub(d, i))
+ local P3x = fp.mul(h, j)
+ local P3y = fp.mul(i, e)
+ local P3z = fp.mul(j, i)
+ local P3t = fp.mul(h, e)
+ return {P3x, P3y, P3z, P3t}
+end
+
+local function add(P1, N1)
+ local P1x, P1y, P1z, P1t = unpack(P1)
+ local N1p, N1m, N1z, N1t = unpack(N1)
+ local a = fp.sub(P1y, P1x)
+ local b = fp.mul(a, N1m)
+ local c = fp.add(P1y, P1x)
+ local d = fp.mul(c, N1p)
+ local e = fp.mul(P1t, N1t)
+ local f = fp.mul(P1z, N1z)
+ local g = fp.sub(d, b)
+ local h = fp.sub(f, e)
+ local i = fp.add(f, e)
+ local j = fp.add(d, b)
+ local P3x = fp.mul(g, h)
+ local P3y = fp.mul(i, j)
+ local P3z = fp.mul(h, i)
+ local P3t = fp.mul(g, j)
+ return {P3x, P3y, P3z, P3t}
+end
+
+local function sub(P1, N1)
+ local P1x, P1y, P1z, P1t = unpack(P1)
+ local N1p, N1m, N1z, N1t = unpack(N1)
+ local a = fp.sub(P1y, P1x)
+ local b = fp.mul(a, N1p)
+ local c = fp.add(P1y, P1x)
+ local d = fp.mul(c, N1m)
+ local e = fp.mul(P1t, N1t)
+ local f = fp.mul(P1z, N1z)
+ local g = fp.sub(d, b)
+ local h = fp.add(f, e)
+ local i = fp.sub(f, e)
+ local j = fp.add(d, b)
+ local P3x = fp.mul(g, h)
+ local P3y = fp.mul(i, j)
+ local P3z = fp.mul(h, i)
+ local P3t = fp.mul(g, j)
+ return {P3x, P3y, P3z, P3t}
+end
+
+local function niels(P1)
+ local P1x, P1y, P1z, P1t = unpack(P1)
+ local N3p = fp.add(P1y, P1x)
+ local N3m = fp.sub(P1y, P1x)
+ local N3z = fp.add(P1z, P1z)
+ local N3t = fp.mul(P1t, K)
+ return {N3p, N3m, N3z, N3t}
+end
+
+local function scale(P1)
+ local P1x, P1y, P1z = unpack(P1)
+ local zInv = fp.invert(P1z)
+ local P3x = fp.mul(P1x, zInv)
+ local P3y = fp.mul(P1y, zInv)
+ local P3z = fp.num(1)
+ local P3t = fp.mul(P3x, P3y)
+ return {P3x, P3y, P3z, P3t}
+end
+
+local function encode(P1)
+ local P1x, P1y = unpack(P1)
+ local y = fp.encode(P1y)
+ local xBit = fp.canonicalize(P1x)[1] % 2
+ return y:sub(1, -2) .. string.char(y:byte(-1) + xBit * 128)
+end
+
+local function decode(str)
+ local P3y = fp.decode(str)
+ local a = fp.square(P3y)
+ local b = fp.sub(a, fp.num(1))
+ local c = fp.mul(a, D)
+ local d = fp.add(c, fp.num(1))
+ local P3x = fp.sqrtDiv(b, d)
+ if not P3x then return nil end
+ local xBit = fp.canonicalize(P3x)[1] % 2
+ if xBit ~= bit32.extract(str:byte(-1), 7) then
+ P3x = fp.carry(fp.sub(fp.P, P3x))
+ end
+ local P3z = fp.num(1)
+ local P3t = fp.mul(P3x, P3y)
+ return {P3x, P3y, P3z, P3t}
+end
+
+G = decode("Xfffffffffffffffffffffffffffffff")
+
+local function signedRadixW(bits, w)
+ -- TODO Find a more elegant way of doing this.
+ local wPow = 2 ^ w
+ local wPowh = wPow / 2
+ local out = {}
+ local acc = 0
+ local mul = 1
+ for i = 1, #bits do
+ acc = acc + bits[i] * mul
+ mul = mul * 2
+ while i == #bits and acc > 0 or mul > wPow do
+ local rem = acc % wPow
+ if rem >= wPowh then rem = rem - wPow end
+ acc = (acc - rem) / wPow
+ mul = mul / wPow
+ out[#out + 1] = rem
+ end
+ end
+ return out
+end
+
+local function radixWTable(P, w)
+ local out = {}
+ for i = 1, 255 / w do
+ local row = {niels(P)}
+ for j = 2, 2 ^ w / 2 do
+ P = add(P, row[1])
+ row[j] = niels(P)
+ end
+ out[i] = row
+ P = double(P)
+ end
+ return out
+end
+
+local G_W = 5
+local G_TABLE = radixWTable(G, G_W)
+
+local function WNAF(bits, w)
+ -- TODO Find a more elegant way of doing this.
+ local wPow = 2 ^ w
+ local wPowh = wPow / 2
+ local out = {}
+ local acc = 0
+ local mul = 1
+ for i = 1, #bits do
+ acc = acc + bits[i] * mul
+ mul = mul * 2
+ while i == #bits and acc > 0 or mul > wPow do
+ if acc % 2 == 0 then
+ acc = acc / 2
+ mul = mul / 2
+ out[#out + 1] = 0
+ else
+ local rem = acc % wPow
+ if rem >= wPowh then rem = rem - wPow end
+ acc = acc - rem
+ out[#out + 1] = rem
+ end
+ end
+ end
+ while out[#out] == 0 do out[#out] = nil end
+ return out
+end
+
+local function WNAFTable(P, w)
+ local dP = double(P)
+ local out = {niels(P)}
+ for i = 3, 2 ^ w, 2 do
+ out[i] = niels(add(dP, out[i - 2]))
+ end
+ return out
+end
+
+local function mulG(bits)
+ local sw = signedRadixW(bits, G_W)
+ local R = O
+ for i = 1, #sw do
+ local b = sw[i]
+ if b > 0 then
+ R = add(R, G_TABLE[i][b])
+ elseif b < 0 then
+ R = sub(R, G_TABLE[i][-b])
+ end
+ end
+ return R
+end
+
+local function mul(P, bits)
+ local naf = WNAF(bits, 5)
+ local tbl = WNAFTable(P, 5)
+ local R = O
+ for i = #naf, 1, -1 do
+ local b = naf[i]
+ if b == 0 then
+ R = double(R)
+ elseif b > 0 then
+ R = add(R, tbl[b])
+ else
+ R = sub(R, tbl[-b])
+ end
+ end
+ return R
+end
+
+return {
+ double = double,
+ add = add,
+ niels = niels,
+ scale = scale,
+ encode = encode,
+ decode = decode,
+ mulG = mulG,
+ mul = mul,
+}
diff --git a/x25519.lua b/x25519.lua
index 18ad851..c5b1d02 100644
--- a/x25519.lua
+++ b/x25519.lua
@@ -5,82 +5,7 @@
local expect = require "cc.expect".expect
local fp = require "ccryptolib.internal.fp"
-local random = require "ccryptolib.random"
-
-local unpack = unpack or table.unpack
-
-local function double(x1, z1)
- local a = fp.add(x1, z1)
- local aa = fp.square(a)
- local b = fp.sub(x1, z1)
- local bb = fp.square(b)
- local c = fp.sub(aa, bb)
- local x3 = fp.mul(aa, bb)
- local z3 = fp.mul(c, fp.add(bb, fp.kmul(c, 121666)))
- return x3, z3
-end
-
-local function step(dxmul, dx, x1, z1, x2, z2)
- local a = fp.add(x1, z1)
- local aa = fp.square(a)
- local b = fp.sub(x1, z1)
- local bb = fp.square(b)
- local e = fp.sub(aa, bb)
- local c = fp.add(x2, z2)
- local d = fp.sub(x2, z2)
- local da = fp.mul(d, a)
- local cb = fp.mul(c, b)
- local x4 = fp.square(fp.add(da, cb))
- local z4 = dxmul(fp.square(fp.sub(da, cb)), dx)
- local x3 = fp.mul(aa, bb)
- local z3 = fp.mul(e, fp.add(bb, fp.kmul(e, 121666)))
- return x3, z3, x4, z4
-end
-
-local function bits(str)
- -- Decode.
- local bytes = {str:byte(1, 32)}
- local out = {}
- for i = 1, 32 do
- local byte = bytes[i]
- for j = -7, 0 do
- local bit = byte % 2
- out[8 * i + j] = bit
- byte = (byte - bit) / 2
- end
- end
-
- -- Clamp.
- out[256] = 0
- out[255] = 1
-
- -- We remove the 3 lowest bits since the ladder already multiplies by 8.
- return {unpack(out, 4)}
-end
-
-local function ladder8(dxmul, dx, bits)
- local x1 = fp.num(1)
- local z1 = fp.num(0)
-
- local z2 = fp.decode(random.random(32))
- local x2 = dxmul(z2, dx)
-
- -- Standard ladder.
- for i = #bits, 1, -1 do
- if bits[i] == 0 then
- x1, z1, x2, z2 = step(dxmul, dx, x1, z1, x2, z2)
- else
- x2, z2, x1, z1 = step(dxmul, dx, x2, z2, x1, z1)
- end
- end
-
- -- Multiply by 8 (double 3 times).
- for _ = 1, 3 do
- x1, z1 = double(x1, z1)
- end
-
- return fp.mul(x1, fp.invert(z1))
-end
+local mont = require "ccryptolib.internal.curve25519"
local mod = {}
@@ -92,7 +17,7 @@ local mod = {}
function mod.publicKey(sk)
expect(1, sk, "string")
assert(#sk == 32, "secret key length must be 32")
- return fp.encode(ladder8(fp.kmul, 9, bits(sk)))
+ return fp.encode(mont.ladder8(fp.kmul, 9, mont.bits(sk)))
end
--- Performs the key exchange.
@@ -106,7 +31,7 @@ function mod.exchange(sk, pk)
assert(#sk == 32, "secret key length must be 32")
expect(2, pk, "string")
assert(#pk == 32, "public key length must be 32")
- return fp.encode(ladder8(fp.mul, fp.decode(pk), bits(sk)))
+ return fp.encode(mont.ladder8(fp.mul, fp.decode(pk), mont.bits(sk)))
end
return mod