295 lines
7.4 KiB
Lua
295 lines
7.4 KiB
Lua
--- The Ed25519 signature scheme.
--
-- @module ed25519
--
|
|
|
|
local expect = require "cc.expect".expect
|
|
local fp = require "ccryptolib.fp"
|
|
local fq = require "ccryptolib.fq"
|
|
local sha512 = require "ccryptolib.sha512"
|
|
local random = require "ccryptolib.random"
|
|
local util = require "ccryptolib.util"
|
|
|
|
local unpack = unpack or table.unpack
|
|
|
|
-- Curve constant d = -121665 / 121666 (mod p).
local D
do
    local numerator = fp.num(-121665)
    local denominatorInv = fp.invert(fp.num(121666))
    D = fp.mul(numerator, denominatorInv)
end

-- 2 * d; niels() folds this factor into the T coordinate.
local K = fp.kmul(D, 2)

-- The neutral element {X = 0, Y = 1, Z = 1, T = 0} in extended coordinates.
local O = {fp.num(0), fp.num(1), fp.num(1), fp.num(0)}

-- The base point. Forward-declared here and assigned right after decode().
local G
|
|
|
|
--- Doubles a point.
-- Reads only X, Y and Z of an extended point {X, Y, Z, T} and returns 2*P
-- in extended coordinates. All four output coordinates carry the same sign
-- flip relative to the usual a = -1 doubling formula, so the result is the
-- same projective point.
-- @tparam table P1 The point to double.
-- @treturn table The doubled point.
local function double(P1)
    local x, y, z = unpack(P1)

    local xx = fp.square(x)
    local yy = fp.square(y)
    local zz2 = fp.kmul(fp.square(z), 2)
    local sum = fp.add(xx, yy)
    -- (x + y)^2 - x^2 - y^2 = 2xy
    local cross = fp.sub(fp.square(fp.add(x, y)), sum)
    local diff = fp.sub(yy, xx)
    local t = fp.sub(zz2, diff)

    return {
        fp.mul(cross, t),
        fp.mul(diff, sum),
        fp.mul(t, diff),
        fp.mul(cross, sum),
    }
end
|
|
|
|
--- Adds a Niels-form point to an extended point.
-- @tparam table P1 Extended point {X, Y, Z, T}.
-- @tparam table N1 Niels point {Y+X, Y-X, 2Z, 2dT}, as built by niels().
-- @treturn table P1 + N1 in extended coordinates.
local function add(P1, N1)
    local x, y, z, t = unpack(P1)
    local np, nm, nz, nt = unpack(N1)

    local mm = fp.mul(fp.sub(y, x), nm)
    local pp = fp.mul(fp.add(y, x), np)
    local tt = fp.mul(t, nt)
    local zz = fp.mul(z, nz)

    local e = fp.sub(pp, mm)
    local f = fp.sub(zz, tt)
    local g = fp.add(zz, tt)
    local h = fp.add(pp, mm)

    return {fp.mul(e, f), fp.mul(g, h), fp.mul(f, g), fp.mul(e, h)}
end
|
|
|
|
--- Subtracts a Niels-form point from an extended point.
-- This mirrors add(): negating a Niels point swaps its Y+X and Y-X entries
-- and negates its 2dT entry. That is why the roles of the first two
-- entries, and the sign used with the T product, are exchanged here.
-- @tparam table P1 Extended point {X, Y, Z, T}.
-- @tparam table N1 Niels point {Y+X, Y-X, 2Z, 2dT}, as built by niels().
-- @treturn table P1 - N1 in extended coordinates.
local function sub(P1, N1)
    local x, y, z, t = unpack(P1)
    local np, nm, nz, nt = unpack(N1)

    local mp = fp.mul(fp.sub(y, x), np)
    local pm = fp.mul(fp.add(y, x), nm)
    local tt = fp.mul(t, nt)
    local zz = fp.mul(z, nz)

    local e = fp.sub(pm, mp)
    local f = fp.add(zz, tt)
    local g = fp.sub(zz, tt)
    local h = fp.add(pm, mp)

    return {fp.mul(e, f), fp.mul(g, h), fp.mul(f, g), fp.mul(e, h)}
end
|
|
|
|
--- Converts an extended point {X, Y, Z, T} to Niels form.
-- @tparam table P1 Extended point.
-- @treturn table {Y+X, Y-X, 2Z, 2dT}, which is the operand form add()/sub() expect.
local function niels(P1)
    local x, y, z, t = unpack(P1)
    return {
        fp.add(y, x),
        fp.sub(y, x),
        fp.add(z, z),
        fp.mul(t, K), -- K = 2d
    }
end
|
|
|
|
--- Normalizes a point to Z = 1.
-- Divides X and Y by Z and recomputes T as X*Y.
-- @tparam table P1 Extended point.
-- @treturn table The equivalent point with Z = 1.
local function scale(P1)
    local x, y, z = unpack(P1)
    local invZ = fp.invert(z)
    local ax = fp.mul(x, invZ)
    local ay = fp.mul(y, invZ)
    return {ax, ay, fp.num(1), fp.mul(ax, ay)}
end
|
|
|
|
--- Encodes a point as a 32-byte string.
-- The result is fp.encode(Y) with the parity of X (the low bit of the first
-- limb of fp.canonicalize(X)) added into the top bit of the last byte. Only
-- X and Y are read, so callers pass a normalized point (see scale()).
-- @tparam table P1 Point with Z = 1.
-- @treturn string The encoding.
local function encode(P1)
    local x, y = unpack(P1)
    local yBytes = fp.encode(y)
    local sign = fp.canonicalize(x)[1] % 2
    local last = yBytes:byte(-1) + sign * 128
    return yBytes:sub(1, -2) .. string.char(last)
end
|
|
|
|
--- Decodes a 32-byte point encoding as produced by encode().
-- Recovers X from Y via x^2 = (y^2 - 1) / (d*y^2 + 1). It then keeps the
-- root whose parity matches bit 7 of the last byte, negating it otherwise.
-- @tparam string str The encoding.
-- @treturn table|nil The point in extended coordinates, or nil when
-- fp.sqrtDiv finds no square root.
local function decode(str)
    local y = fp.decode(str)
    local yy = fp.square(y)
    local num = fp.sub(yy, fp.num(1))
    local den = fp.add(fp.mul(yy, D), fp.num(1))

    local x = fp.sqrtDiv(num, den)
    if not x then return nil end

    local parity = fp.canonicalize(x)[1] % 2
    if parity ~= bit32.extract(str:byte(-1), 7) then
        x = fp.carry(fp.neg(x))
    end

    return {x, y, fp.num(1), fp.mul(x, y)}
end

-- The base point: "X" is byte 0x58 and each "f" is 0x66. This is intended
-- to be the standard Ed25519 base point encoding (y = 4/5, RFC 8032 5.1).
G = decode("Xfffffffffffffffffffffffffffffff")
|
|
|
|
--- Recodes a scalar into signed base-2^w digits.
-- The scalar is sum(bits[i] * 2^(i-1)), i.e. bits[1] is least significant.
-- The digits d satisfy -2^(w-1) <= d < 2^(w-1) and
-- sum(digits[i] * 2^(w*(i-1))) equals the scalar.
-- A digit is emitted only once w+1 bits beyond the previous one are known,
-- so carries from negative digits are already accounted for.
-- TODO Find a more elegant way of doing this.
-- @tparam table bits Little-endian list of 0/1 values.
-- @tparam number w Window width.
-- @treturn table Digits, least significant first.
local function signedRadixW(bits, w)
    local radix = 2 ^ w
    local half = radix / 2
    local digits = {}
    local carry = 0  -- not-yet-emitted value at the current digit's scale
    local weight = 1 -- weight of the next incoming bit relative to `carry`
    local count = #bits

    for idx = 1, count do
        carry = carry + bits[idx] * weight
        weight = weight * 2
        -- After the last bit, flush whatever remains.
        while (idx == count and carry > 0) or weight > radix do
            local digit = carry % radix
            if digit >= half then digit = digit - radix end
            carry = (carry - digit) / radix
            weight = weight / radix
            digits[#digits + 1] = digit
        end
    end

    return digits
end
|
|
|
|
--- Precomputes the fixed-base table used by mulG.
-- There are floor(255 / w) rows. Row i holds, in Niels form, the multiples
-- j * B_i for j = 1 .. 2^(w-1), where B_1 = P and B_(i+1) = 2^w * B_i.
-- @tparam table P Extended point to build the table for.
-- @tparam number w Window width.
-- @treturn table The rows.
local function radixWTable(P, w)
    local rows = {}
    local perRow = 2 ^ w / 2

    for i = 1, 255 / w do
        local base = niels(P)
        local row = {base}
        for j = 2, perRow do
            P = add(P, base) -- P is now j * B_i
            row[j] = niels(P)
        end
        rows[i] = row
        -- P = 2^(w-1) * B_i here; one doubling gives B_(i+1).
        P = double(P)
    end

    return rows
end

-- Window width for multiplication by the base point, and its table.
local G_W = 5
local G_TABLE = radixWTable(G, G_W)
|
|
|
|
--- Recodes a scalar into a width-w NAF operation sequence.
-- Each entry is either 0, meaning a halving (a doubling when replayed), or
-- an odd digit d with |d| <= 2^(w-1) - 1, meaning d was subtracted (added
-- when replayed). mul() replays the entries from last to first.
-- Trailing zeros are dropped because they are doublings of the neutral
-- element.
-- TODO Find a more elegant way of doing this.
-- @tparam table bits Little-endian list of 0/1 values.
-- @tparam number w Window width.
-- @treturn table The operation sequence.
local function WNAF(bits, w)
    local radix = 2 ^ w
    local half = radix / 2
    local ops = {}
    local acc = 0
    local weight = 1
    local count = #bits

    for idx = 1, count do
        acc = acc + bits[idx] * weight
        weight = weight * 2
        -- Only act once acc's low bits can no longer change, or at the end.
        while (idx == count and acc > 0) or weight > radix do
            local digit = 0
            if acc % 2 == 0 then
                acc = acc / 2
                weight = weight / 2
            else
                digit = acc % radix
                if digit >= half then digit = digit - radix end
                acc = acc - digit
            end
            ops[#ops + 1] = digit
        end
    end

    while ops[#ops] == 0 do ops[#ops] = nil end
    return ops
end
|
|
|
|
--- Precomputes the odd multiples of P used by mul().
-- Entry i, for odd i in 1 .. 2^(w-1) - 1, holds i * P in Niels form.
-- WNAF(bits, w) reduces each digit to an odd value in
-- [-(2^(w-1) - 1), 2^(w-1) - 1], so these are the only entries ever read.
-- Multiples up to 2^w - 1 would just be wasted point additions.
-- @tparam table P Extended point.
-- @tparam number w Window width; must match the one passed to WNAF.
-- @treturn table Sparse table indexed by odd integers.
local function WNAFTable(P, w)
    local dP = double(P)
    local out = {niels(P)}
    for i = 3, 2 ^ (w - 1), 2 do
        -- i*P = 2P + (i-2)*P
        out[i] = niels(add(dP, out[i - 2]))
    end
    return out
end
|
|
|
|
--- Multiplies the base point G by a scalar.
-- Uses signed radix-2^G_W digits. Digit i selects an entry from row i of
-- G_TABLE, so no doublings are needed.
-- @tparam table bits The scalar as a little-endian list of 0/1 values.
-- @treturn table The product in extended coordinates.
local function mulG(bits)
    local digits = signedRadixW(bits, G_W)
    local acc = O
    for row, digit in ipairs(digits) do
        if digit > 0 then
            acc = add(acc, G_TABLE[row][digit])
        elseif digit < 0 then
            acc = sub(acc, G_TABLE[row][-digit])
        end
    end
    return acc
end
|
|
|
|
--- Multiplies an arbitrary point by a scalar using a width-5 NAF.
-- @tparam table P Point in extended coordinates.
-- @tparam table bits The scalar as a little-endian list of 0/1 values.
-- @treturn table The product in extended coordinates.
local function mul(P, bits)
    local ops = WNAF(bits, 5)
    local odd = WNAFTable(P, 5)
    local acc = O
    -- WNAF records its steps least-significant first; replay them backwards.
    for i = #ops, 1, -1 do
        local digit = ops[i]
        if digit == 0 then
            acc = double(acc)
        elseif digit > 0 then
            acc = add(acc, odd[digit])
        else
            acc = sub(acc, odd[-digit])
        end
    end
    return acc
end
|
|
|
|
--- Computes the public key belonging to a secret key.
-- The original body called `blinding.decodeBlinded`, but `blinding` is
-- never required or defined in this module, so every call raised an error.
-- The key is now derived with the same fq.maskedDecode that sign() uses.
-- sign() answers with s = k - e*xm + e*m and verify() accepts iff
-- encode(s*G + e*pk) == encode(k*G). The matching public key is therefore
-- xm*G - m*G. It is computed in that form so the unmasked scalar is never
-- formed (assumes maskedDecode's xm - m does not depend on the mask, which
-- sign/verify already rely on).
-- @tparam string sk The 32-byte secret key.
-- @treturn string The 32-byte encoded public key.
local function publicKey(sk)
    expect(1, sk, "string")
    assert(#sk == 32, "secret key length must be 32")

    -- FIXME SHA512 isn't constant-time.
    local h = sha512.digest(sk):sub(1, 32)
    local xm, m = fq.maskedDecode(h, random.random(32))

    local XmG = mulG(fq.bits(xm))
    local MG = mulG(fq.bits(m))
    return encode(scale(sub(XmG, niels(MG))))
end
|
|
|
|
--- Signs a message.
-- The nonce k is drawn from random.random, not derived from the key.
-- @tparam string sk The 32-byte secret key.
-- @tparam string pk The 32-byte public key belonging to sk.
-- @tparam string msg The message to sign.
-- @treturn string The 64-byte signature: encoded R followed by fq.encode(s).
local function sign(sk, pk, msg)
    expect(1, sk, "string")
    assert(#sk == 32, "secret key length must be 32")
    expect(2, pk, "string")
    assert(#pk == 32, "public key length must be 32")
    expect(3, msg, "string")

    -- Decode the secret scalar from the key hash together with a random mask
    -- (presumably xm is the masked scalar and m the mask; see fq.maskedDecode).
    local xm, m = fq.maskedDecode(sha512.digest(sk):sub(1, 32), random.random(32))

    -- Commitment R = k*G.
    local k = fq.decodeWide(random.random(64))
    local rStr = encode(scale(mulG(fq.bits(k))))

    -- Challenge e = H(R || pk || msg), reduced by fq.decodeWide.
    local e = fq.decodeWide(sha512.digest(rStr .. pk .. msg))

    -- Response s = k - e*xm + e*m.
    local exm = fq.mul(e, xm)
    local em = fq.mul(e, m)
    local s = fq.add(fq.sub(k, exm), em)

    return rStr .. fq.encode(s)
end
|
|
|
|
-- The group order L = 2^252 + 27742317777372353535851937790883648493,
-- as little-endian bytes.
local ORDER_BYTES = {
    0xed, 0xd3, 0xf5, 0x5c, 0x1a, 0x63, 0x12, 0x58,
    0xd6, 0x9c, 0xf7, 0xa2, 0xde, 0xf9, 0xde, 0x14,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10,
}

--- Whether a 32-byte little-endian string encodes an integer below L.
local function isCanonicalScalar(str)
    for i = 32, 1, -1 do
        local byte, limit = str:byte(i), ORDER_BYTES[i]
        if byte ~= limit then return byte < limit end
    end
    return false -- exactly L
end

--- Verifies a signature produced by sign().
-- Accepts iff encode(s*G + e*pk) == R, where e = H(R || pk || msg).
-- @tparam string pk The 32-byte public key.
-- @tparam string msg The signed message.
-- @tparam string sig The 64-byte signature.
-- @treturn boolean Whether the signature is valid.
local function verify(pk, msg, sig)
    expect(1, pk, "string")
    assert(#pk == 32, "public key length must be 32")
    expect(2, msg, "string")
    expect(3, sig, "string")
    assert(#sig == 64, "signature length must be 64")

    local rStr = sig:sub(1, 32)
    local sStr = sig:sub(33)

    -- Reject s >= L (RFC 8032 5.1.7). Without this, signatures are malleable,
    -- and a 256-bit s can yield more radix digits than G_TABLE has rows.
    if not isCanonicalScalar(sStr) then return false end

    -- An undecodable public key cannot verify anything.
    local Y = decode(pk)
    if not Y then return false end

    local ev = fq.decodeWide(sha512.digest(rStr .. pk .. msg))
    local evBits = fq.bits(ev)
    local sBits = util.rebaseLE({sStr:byte(1, -1)}, 256, 2)

    -- add() takes its second operand in Niels form.
    local Rv = add(mulG(sBits), niels(mul(Y, evBits)))

    return encode(scale(Rv)) == rStr
end
|
|
|
|
-- Public interface of the module.
return {
    sign = sign,
    verify = verify,
    publicKey = publicKey,
}
|