diff --git a/ed25519.lua b/ed25519.lua index a3907fa..1942aaa 100644 --- a/ed25519.lua +++ b/ed25519.lua @@ -1,4 +1,4 @@ ---- The Ed25519 signature scheme. +--- The Ed25519 digital signature scheme. -- -- **Note:** This library is provided for compatibility and provides no side -- channel resistance by itself. @@ -6,228 +6,10 @@ -- @module ed25519 -- -local expect = require "cc.expect".expect -local fp = require "ccryptolib.internal.fp" -local fq = require "ccryptolib.internal.fq" -local sha512 = require "ccryptolib.internal.sha512" - -local unpack = unpack or table.unpack - -local D = fp.mul(fp.num(-121665), fp.invert(fp.num(121666))) -local K = fp.kmul(D, 2) - -local O = {fp.num(0), fp.num(1), fp.num(1), fp.num(0)} -local G = nil - -local function double(P1) - local P1x, P1y, P1z = unpack(P1) - local a = fp.square(P1x) - local b = fp.square(P1y) - local c = fp.square(P1z) - local d = fp.kmul(c, 2) - local e = fp.add(a, b) - local f = fp.add(P1x, P1y) - local g = fp.square(f) - local h = fp.sub(g, e) - local i = fp.sub(b, a) - local j = fp.sub(d, i) - local P3x = fp.mul(h, j) - local P3y = fp.mul(i, e) - local P3z = fp.mul(j, i) - local P3t = fp.mul(h, e) - return {P3x, P3y, P3z, P3t} -end - -local function add(P1, N1) - local P1x, P1y, P1z, P1t = unpack(P1) - local N1p, N1m, N1z, N1t = unpack(N1) - local a = fp.sub(P1y, P1x) - local b = fp.mul(a, N1m) - local c = fp.add(P1y, P1x) - local d = fp.mul(c, N1p) - local e = fp.mul(P1t, N1t) - local f = fp.mul(P1z, N1z) - local g = fp.sub(d, b) - local h = fp.sub(f, e) - local i = fp.add(f, e) - local j = fp.add(d, b) - local P3x = fp.mul(g, h) - local P3y = fp.mul(i, j) - local P3z = fp.mul(h, i) - local P3t = fp.mul(g, j) - return {P3x, P3y, P3z, P3t} -end - -local function sub(P1, N1) - local P1x, P1y, P1z, P1t = unpack(P1) - local N1p, N1m, N1z, N1t = unpack(N1) - local a = fp.sub(P1y, P1x) - local b = fp.mul(a, N1p) - local c = fp.add(P1y, P1x) - local d = fp.mul(c, N1m) - local e = fp.mul(P1t, N1t) - local f = fp.mul(P1z, N1z) - local g = fp.sub(d, b) - local h = fp.add(f, e) - local i = fp.sub(f, e) - local j = fp.add(d, b) - local P3x = fp.mul(g, h) - local P3y = fp.mul(i, j) - local P3z = fp.mul(h, i) - local P3t = fp.mul(g, j) - return {P3x, P3y, P3z, P3t} -end - -local function niels(P1) - local P1x, P1y, P1z, P1t = unpack(P1) - local N3p = fp.add(P1y, P1x) - local N3m = fp.sub(P1y, P1x) - local N3z = fp.add(P1z, P1z) - local N3t = fp.mul(P1t, K) - return {N3p, N3m, N3z, N3t} -end - -local function scale(P1) - local P1x, P1y, P1z = unpack(P1) - local zInv = fp.invert(P1z) - local P3x = fp.mul(P1x, zInv) - local P3y = fp.mul(P1y, zInv) - local P3z = fp.num(1) - local P3t = fp.mul(P3x, P3y) - return {P3x, P3y, P3z, P3t} -end - -local function encode(P1) - local P1x, P1y = unpack(P1) - local y = fp.encode(P1y) - local xBit = fp.canonicalize(P1x)[1] % 2 - return y:sub(1, -2) .. string.char(y:byte(-1) + xBit * 128) -end - -local function decode(str) - local P3y = fp.decode(str) - local a = fp.square(P3y) - local b = fp.sub(a, fp.num(1)) - local c = fp.mul(a, D) - local d = fp.add(c, fp.num(1)) - local P3x = fp.sqrtDiv(b, d) - if not P3x then return nil end - local xBit = fp.canonicalize(P3x)[1] % 2 - if xBit ~= bit32.extract(str:byte(-1), 7) then - P3x = fp.neg(P3x) - P3x = fp.carry(P3x) - end - local P3z = fp.num(1) - local P3t = fp.mul(P3x, P3y) - return {P3x, P3y, P3z, P3t} -end - -G = decode("Xfffffffffffffffffffffffffffffff") - -local function signedRadixW(bits, w) - -- TODO Find a more elegant way of doing this. - local wPow = 2 ^ w - local wPowh = wPow / 2 - local out = {} - local acc = 0 - local mul = 1 - for i = 1, #bits do - acc = acc + bits[i] * mul - mul = mul * 2 - while i == #bits and acc > 0 or mul > wPow do - local rem = acc % wPow - if rem >= wPowh then rem = rem - wPow end - acc = (acc - rem) / wPow - mul = mul / wPow - out[#out + 1] = rem - end - end - return out -end - -local function radixWTable(P, w) - local out = {} - for i = 1, 255 / w do - local row = {niels(P)} - for j = 2, 2 ^ w / 2 do - P = add(P, row[1]) - row[j] = niels(P) - end - out[i] = row - P = double(P) - end - return out -end - -local G_W = 5 -local G_TABLE = radixWTable(G, G_W) - -local function WNAF(bits, w) - -- TODO Find a more elegant way of doing this. - local wPow = 2 ^ w - local wPowh = wPow / 2 - local out = {} - local acc = 0 - local mul = 1 - for i = 1, #bits do - acc = acc + bits[i] * mul - mul = mul * 2 - while i == #bits and acc > 0 or mul > wPow do - if acc % 2 == 0 then - acc = acc / 2 - mul = mul / 2 - out[#out + 1] = 0 - else - local rem = acc % wPow - if rem >= wPowh then rem = rem - wPow end - acc = acc - rem - out[#out + 1] = rem - end - end - end - while out[#out] == 0 do out[#out] = nil end - return out -end - -local function WNAFTable(P, w) - local dP = double(P) - local out = {niels(P)} - for i = 3, 2 ^ w, 2 do - out[i] = niels(add(dP, out[i - 2])) - end - return out -end - -local function mulG(bits) - local sw = signedRadixW(bits, G_W) - local R = O - for i = 1, #sw do - local b = sw[i] - if b > 0 then - R = add(R, G_TABLE[i][b]) - elseif b < 0 then - R = sub(R, G_TABLE[i][-b]) - end - end - return R -end - -local function mul(P, bits) - local naf = WNAF(bits, 5) - local tbl = WNAFTable(P, 5) - local R = O - for i = #naf, 1, -1 do - local b = naf[i] - if b == 0 then - R = double(R) - elseif b > 0 then - R = add(R, tbl[b]) - else - R = sub(R, tbl[-b]) - end - end - return R -end +local expect = require "cc.expect".expect +local fq = require "ccryptolib.internal.fq" +local ed25519 = require "ccryptolib.internal.ed25519" +local sha512 = require "ccryptolib.internal.sha512" local mod = {} @@ -243,7 +25,7 @@ function mod.publicKey(sk) local h = sha512.digest(sk) local x = fq.decodeClamped(h:sub(1, 32)) - return encode(scale(mulG(fq.bits(x)))) + return ed25519.encode(ed25519.scale(ed25519.mulG(fq.bits(x)))) end --- Signs a message. @@ -266,8 +48,8 @@ function mod.sign(sk, pk, msg) -- Commitment. local k = fq.decodeWide(sha512.digest(h:sub(33) .. msg)) - local r = mulG(fq.bits(k)) - local rStr = encode(scale(r)) + local r = ed25519.mulG(fq.bits(k)) + local rStr = ed25519.encode(ed25519.scale(r)) -- Challenge. local e = fq.decodeWide(sha512.digest(rStr .. pk .. msg)) @@ -293,7 +75,7 @@ function mod.verify(pk, msg, sig) expect(3, sig, "string") assert(#sig == 64, "signature length must be 64") - local y = decode(pk) + local y = ed25519.decode(pk) if not y then return nil end local rStr = sig:sub(1, 32) @@ -301,11 +83,11 @@ function mod.verify(pk, msg, sig) local e = fq.decodeWide(sha512.digest(rStr .. pk .. msg)) - local gs = mulG(fq.bits(fq.decode(sStr))) - local ye = mul(y, fq.bits(e)) - local rv = add(gs, niels(ye)) + local gs = ed25519.mulG(fq.bits(fq.decode(sStr))) + local ye = ed25519.mul(y, fq.bits(e)) + local rv = ed25519.add(gs, ed25519.niels(ye)) - return encode(scale(rv)) == rStr + return ed25519.encode(ed25519.scale(rv)) == rStr end return mod diff --git a/internal/ed25519.lua b/internal/ed25519.lua new file mode 100644 index 0000000..6ac4e5d --- /dev/null +++ b/internal/ed25519.lua @@ -0,0 +1,229 @@ +local fp = require "ccryptolib.internal.fp" + +local unpack = unpack or table.unpack + +local D = fp.mul(fp.num(-121665), fp.invert(fp.num(121666))) +local K = fp.kmul(D, 2) + +local O = {fp.num(0), fp.num(1), fp.num(1), fp.num(0)} +local G = nil + +local function double(P1) + local P1x, P1y, P1z = unpack(P1) + local a = fp.square(P1x) + local b = fp.square(P1y) + local c = fp.square(P1z) + local d = fp.kmul(c, 2) + local e = fp.add(a, b) + local f = fp.add(P1x, P1y) + local g = fp.square(f) + local h = fp.sub(g, e) + local i = fp.sub(b, a) + local j = fp.sub(d, i) + local P3x = fp.mul(h, j) + local P3y = fp.mul(i, e) + local P3z = fp.mul(j, i) + local P3t = fp.mul(h, e) + return {P3x, P3y, P3z, P3t} +end + +local function add(P1, N1) + local P1x, P1y, P1z, P1t = unpack(P1) + local N1p, N1m, N1z, N1t = unpack(N1) + local a = fp.sub(P1y, P1x) + local b = fp.mul(a, N1m) + local c = fp.add(P1y, P1x) + local d = fp.mul(c, N1p) + local e = fp.mul(P1t, N1t) + local f = fp.mul(P1z, N1z) + local g = fp.sub(d, b) + local h = fp.sub(f, e) + local i = fp.add(f, e) + local j = fp.add(d, b) + local P3x = fp.mul(g, h) + local P3y = fp.mul(i, j) + local P3z = fp.mul(h, i) + local P3t = fp.mul(g, j) + return {P3x, P3y, P3z, P3t} +end + +local function sub(P1, N1) + local P1x, P1y, P1z, P1t = unpack(P1) + local N1p, N1m, N1z, N1t = unpack(N1) + local a = fp.sub(P1y, P1x) + local b = fp.mul(a, N1p) + local c = fp.add(P1y, P1x) + local d = fp.mul(c, N1m) + local e = fp.mul(P1t, N1t) + local f = fp.mul(P1z, N1z) + local g = fp.sub(d, b) + local h = fp.add(f, e) + local i = fp.sub(f, e) + local j = fp.add(d, b) + local P3x = fp.mul(g, h) + local P3y = fp.mul(i, j) + local P3z = fp.mul(h, i) + local P3t = fp.mul(g, j) + return {P3x, P3y, P3z, P3t} +end + +local function niels(P1) + local P1x, P1y, P1z, P1t = unpack(P1) + local N3p = fp.add(P1y, P1x) + local N3m = fp.sub(P1y, P1x) + local N3z = fp.add(P1z, P1z) + local N3t = fp.mul(P1t, K) + return {N3p, N3m, N3z, N3t} +end + +local function scale(P1) + local P1x, P1y, P1z = unpack(P1) + local zInv = fp.invert(P1z) + local P3x = fp.mul(P1x, zInv) + local P3y = fp.mul(P1y, zInv) + local P3z = fp.num(1) + local P3t = fp.mul(P3x, P3y) + return {P3x, P3y, P3z, P3t} +end + +local function encode(P1) + local P1x, P1y = unpack(P1) + local y = fp.encode(P1y) + local xBit = fp.canonicalize(P1x)[1] % 2 + return y:sub(1, -2) .. string.char(y:byte(-1) + xBit * 128) +end + +local function decode(str) + local P3y = fp.decode(str) + local a = fp.square(P3y) + local b = fp.sub(a, fp.num(1)) + local c = fp.mul(a, D) + local d = fp.add(c, fp.num(1)) + local P3x = fp.sqrtDiv(b, d) + if not P3x then return nil end + local xBit = fp.canonicalize(P3x)[1] % 2 + if xBit ~= bit32.extract(str:byte(-1), 7) then + P3x = fp.neg(P3x) + P3x = fp.carry(P3x) + end + local P3z = fp.num(1) + local P3t = fp.mul(P3x, P3y) + return {P3x, P3y, P3z, P3t} +end + +G = decode("Xfffffffffffffffffffffffffffffff") + +local function signedRadixW(bits, w) + -- TODO Find a more elegant way of doing this. + local wPow = 2 ^ w + local wPowh = wPow / 2 + local out = {} + local acc = 0 + local mul = 1 + for i = 1, #bits do + acc = acc + bits[i] * mul + mul = mul * 2 + while i == #bits and acc > 0 or mul > wPow do + local rem = acc % wPow + if rem >= wPowh then rem = rem - wPow end + acc = (acc - rem) / wPow + mul = mul / wPow + out[#out + 1] = rem + end + end + return out +end + +local function radixWTable(P, w) + local out = {} + for i = 1, 255 / w do + local row = {niels(P)} + for j = 2, 2 ^ w / 2 do + P = add(P, row[1]) + row[j] = niels(P) + end + out[i] = row + P = double(P) + end + return out +end + +local G_W = 5 +local G_TABLE = radixWTable(G, G_W) + +local function WNAF(bits, w) + -- TODO Find a more elegant way of doing this. + local wPow = 2 ^ w + local wPowh = wPow / 2 + local out = {} + local acc = 0 + local mul = 1 + for i = 1, #bits do + acc = acc + bits[i] * mul + mul = mul * 2 + while i == #bits and acc > 0 or mul > wPow do + if acc % 2 == 0 then + acc = acc / 2 + mul = mul / 2 + out[#out + 1] = 0 + else + local rem = acc % wPow + if rem >= wPowh then rem = rem - wPow end + acc = acc - rem + out[#out + 1] = rem + end + end + end + while out[#out] == 0 do out[#out] = nil end + return out +end + +local function WNAFTable(P, w) + local dP = double(P) + local out = {niels(P)} + for i = 3, 2 ^ w, 2 do + out[i] = niels(add(dP, out[i - 2])) + end + return out +end + +local function mulG(bits) + local sw = signedRadixW(bits, G_W) + local R = O + for i = 1, #sw do + local b = sw[i] + if b > 0 then + R = add(R, G_TABLE[i][b]) + elseif b < 0 then + R = sub(R, G_TABLE[i][-b]) + end + end + return R +end + +local function mul(P, bits) + local naf = WNAF(bits, 5) + local tbl = WNAFTable(P, 5) + local R = O + for i = #naf, 1, -1 do + local b = naf[i] + if b == 0 then + R = double(R) + elseif b > 0 then + R = add(R, tbl[b]) + else + R = sub(R, tbl[-b]) + end + end + return R +end + +return { + add = add, + niels = niels, + scale = scale, + encode = encode, + decode = decode, + mulG = mulG, + mul = mul, +}