Rework comments to new annotation style

This commit is contained in:
Miguel Oliveira 2023-06-08 01:15:16 -03:00
parent 6fbbab378a
commit cb620cfb0a
23 changed files with 451 additions and 575 deletions

View file

@ -1,7 +1,4 @@
--- The ChaCha20Poly1305AEAD authenticated encryption with associated data (AEAD) construction. --- The ChaCha20Poly1305AEAD authenticated encryption with associated data (AEAD) construction.
--
-- @module aead
--
local expect = require "cc.expect".expect local expect = require "cc.expect".expect
local lassert = require "ccryptolib.internal.util".lassert local lassert = require "ccryptolib.internal.util".lassert
@ -14,15 +11,13 @@ local u4x4, fmt4x4 = packing.compileUnpack("<I4I4I4I4")
local bxor = bit32.bxor local bxor = bit32.bxor
--- Encrypts a message. --- Encrypts a message.
-- --- @param key string A 32-byte random key.
-- @tparam string key A 32-byte random key. --- @param nonce string A 12-byte per-message unique nonce.
-- @tparam string nonce A 12-byte per-message unique nonce. --- @param message string The message to be encrypted.
-- @tparam string message The message to be encrypted. --- @param aad string aad Arbitrary associated data to also authenticate.
-- @tparam string aad Arbitrary associated data to authenticate on decryption. --- @param rounds number? The number of ChaCha20 rounds to use. Defaults to 20.
-- @tparam[opt=20] number rounds The number of ChaCha20 rounds to use. --- @return string ctx The ciphertext.
-- @treturn string The ciphertext. --- @return string tag The 16-byte authentication tag.
-- @treturn string The 16-byte authentication tag.
--
local function encrypt(key, nonce, message, aad, rounds) local function encrypt(key, nonce, message, aad, rounds)
expect(1, key, "string") expect(1, key, "string")
lassert(#key == 32, "key length must be 32", 2) lassert(#key == 32, "key length must be 32", 2)
@ -53,16 +48,13 @@ local function encrypt(key, nonce, message, aad, rounds)
end end
--- Decrypts a message. --- Decrypts a message.
-- --- @param key string The key used on encryption.
-- @tparam string key The key used on encryption. --- @param nonce string The nonce used on encryption.
-- @tparam string nonce The nonce used on encryption. --- @param ciphertext string The ciphertext to be decrypted.
-- @tparam string ciphertext The ciphertext to be decrypted. --- @param aad string The arbitrary associated data used on encryption.
-- @tparam string aad The arbitrary associated data used on encryption. --- @param tag string The authentication tag returned on encryption.
-- @tparam string tag The authentication tag returned on encryption. --- @param rounds number The number of rounds used on encryption.
-- @tparam[opt=20] number rounds The number of rounds used on encryption. --- @return string? msg The decrypted plaintext. Or nil on auth failure.
-- @treturn[1] string The decrypted plaintext.
-- @treturn[2] nil If authentication has failed.
--
local function decrypt(key, nonce, tag, ciphertext, aad, rounds) local function decrypt(key, nonce, tag, ciphertext, aad, rounds)
expect(1, key, "string") expect(1, key, "string")
lassert(#key == 32, "key length must be 32", 2) lassert(#key == 32, "key length must be 32", 2)

View file

@ -1,7 +1,4 @@
--- The BLAKE3 cryptographic hash function. --- The BLAKE3 cryptographic hash function.
--
-- @module blake3
--
local expect = require "cc.expect".expect local expect = require "cc.expect".expect
local lassert = require "ccryptolib.internal.util".lassert local lassert = require "ccryptolib.internal.util".lassert
@ -208,11 +205,9 @@ end
local mod = {} local mod = {}
--- Hashes data using BLAKE3. --- Hashes data using BLAKE3.
-- --- @param message string The input message.
-- @tparam string message The input message. --- @param len number? The desired hash length, in bytes. Defaults to 32.
-- @tparam[opt=32] number len The desired hash length, in bytes. --- @return string hash The hash.
-- @treturn string The hash.
--
function mod.digest(message, len) function mod.digest(message, len)
expect(1, message, "string") expect(1, message, "string")
len = expect(2, len, "number", "nil") or 32 len = expect(2, len, "number", "nil") or 32
@ -222,12 +217,10 @@ function mod.digest(message, len)
end end
--- Performs a keyed hash. --- Performs a keyed hash.
-- --- @param key string A 32-byte random key.
-- @tparam string key A 32-byte random key. --- @param message string The input message.
-- @tparam string message The input message. --- @param len number? The desired hash length, in bytes. Defaults to 32.
-- @tparam[opt=32] number len The desired hash length, in bytes. --- @return string hash The keyed hash.
-- @treturn string The keyed hash.
--
function mod.digestKeyed(key, message, len) function mod.digestKeyed(key, message, len)
expect(1, key, "string") expect(1, key, "string")
lassert(#key == 32, "key length must be 32", 2) lassert(#key == 32, "key length must be 32", 2)
@ -239,14 +232,15 @@ function mod.digestKeyed(key, message, len)
end end
--- Makes a context-based key derivation function (KDF). --- Makes a context-based key derivation function (KDF).
-- --- @param context string The context for the KDF.
-- @tparam string context The context for the KDF. --- @return fun(material: string, len: number?): string kdf The KDF.
-- @treturn function(material:string [, len:number]):string The KDF.
--
function mod.deriveKey(context) function mod.deriveKey(context)
expect(1, context, "string") expect(1, context, "string")
local iv = {u8x4(fmt8x4, blake3(IV, DERIVE_KEY_CONTEXT, context, 32), 1)} local iv = {u8x4(fmt8x4, blake3(IV, DERIVE_KEY_CONTEXT, context, 32), 1)}
--- Derives a key.
--- @param material string The keying material.
--- @param len number? The desired hash length, in bytes. Defaults to 32.
return function(material, len) return function(material, len)
expect(1, material, "string") expect(1, material, "string")
len = expect(2, len, "number", "nil") or 32 len = expect(2, len, "number", "nil") or 32

View file

@ -1,7 +1,4 @@
--- The ChaCha20 stream cipher. --- The ChaCha20 stream cipher.
--
-- @module chacha20
--
local expect = require "cc.expect".expect local expect = require "cc.expect".expect
local lassert = require "ccryptolib.internal.util".lassert local lassert = require "ccryptolib.internal.util".lassert
@ -17,14 +14,12 @@ local u16x4 = packing.compileUnpack(fmt16x4)
local mod = {} local mod = {}
--- Encrypts/Decrypts data using ChaCha20. --- Encrypts/Decrypts data using ChaCha20.
-- --- @param key string A 32-byte random key.
-- @tparam string key A 32-byte random key. --- @param nonce string A 12-byte per-message unique nonce.
-- @tparam string nonce A 12-byte per-message unique nonce. --- @param message string A plaintext or ciphertext.
-- @tparam string message A plaintext or ciphertext. --- @param rounds number? The number of ChaCha20 rounds to use. Defaults to 20.
-- @tparam[opt=20] number rounds The number of ChaCha20 rounds to use. --- @param offset number? The block offset to generate the keystream at. Defaults to 1.
-- @tparam[opt=1] number offset The block offset to generate the keystream at. --- @return string out The resulting ciphertext or plaintext.
-- @treturn string The resulting ciphertext or plaintext.
--
function mod.crypt(key, nonce, message, rounds, offset) function mod.crypt(key, nonce, message, rounds, offset)
expect(1, key, "string") expect(1, key, "string")
lassert(#key == 32, "key length must be 32", 2) lassert(#key == 32, "key length must be 32", 2)

View file

@ -1,7 +1,4 @@
--- The Ed25519 digital signature scheme. --- The Ed25519 digital signature scheme.
--
-- @module ed25519
--
local expect = require "cc.expect".expect local expect = require "cc.expect".expect
local lassert = require "ccryptolib.internal.util".lassert local lassert = require "ccryptolib.internal.util".lassert
@ -13,10 +10,8 @@ local random = require "ccryptolib.random"
local mod = {} local mod = {}
--- Computes a public key from a secret key. --- Computes a public key from a secret key.
-- --- @param sk string A random 32-byte secret key.
-- @tparam string sk A random 32-byte secret key. --- @return string pk The matching 32-byte public key.
-- @treturn string The matching 32-byte public key.
--
function mod.publicKey(sk) function mod.publicKey(sk)
expect(1, sk, "string") expect(1, sk, "string")
assert(#sk == 32, "secret key length must be 32") assert(#sk == 32, "secret key length must be 32")
@ -28,12 +23,10 @@ function mod.publicKey(sk)
end end
--- Signs a message. --- Signs a message.
-- --- @param sk string The signer's secret key.
-- @tparam string sk The signer's secret key. --- @param pk string The signer's public key.
-- @tparam string pk The signer's public key. --- @param msg string The message to be signed.
-- @tparam string msg The message to be signed. --- @return string sig The 64-byte signature on the message.
-- @treturn string The 64-byte signature on the message.
--
function mod.sign(sk, pk, msg) function mod.sign(sk, pk, msg)
expect(1, sk, "string") expect(1, sk, "string")
lassert(#sk == 32, "secret key length must be 32", 2) lassert(#sk == 32, "secret key length must be 32", 2)
@ -62,21 +55,19 @@ function mod.sign(sk, pk, msg)
end end
--- Verifies a signature on a message. --- Verifies a signature on a message.
-- --- @param pk string The signer's public key.
-- @tparam string pk The signer's public key. --- @param msg string The signed message.
-- @tparam string msg The signed message. --- @param sig string The alleged signature.
-- @tparam string sig The signature. --- @return boolean valid Whether the signature is valid or not.
-- @treturn boolean Whether the signature is valid or not.
--
function mod.verify(pk, msg, sig) function mod.verify(pk, msg, sig)
expect(1, pk, "string") expect(1, pk, "string")
lassert(#pk == 32, "public key length must be 32", 2) lassert(#pk == 32, "public key length must be 32", 2) --- @cast pk String32
expect(2, msg, "string") expect(2, msg, "string")
expect(3, sig, "string") expect(3, sig, "string")
lassert(#sig == 64, "signature length must be 64", 2) lassert(#sig == 64, "signature length must be 64", 2)
local y = ed.decode(pk) local y = ed.decode(pk)
if not y then return nil end if not y then return false end
local rStr = sig:sub(1, 32) local rStr = sig:sub(1, 32)
local sStr = sig:sub(33) local sStr = sig:sub(33)

View file

@ -1,19 +1,16 @@
--- Point arithmetic on the Curve25519 Montgomery curve. --- Point arithmetic on the Curve25519 Montgomery curve.
--
-- :::note Internal Module
-- This module is meant for internal use within the library. Its API is unstable
-- and subject to change without major version bumps.
-- :::
--
-- <br />
--
-- @module[kind=internal] internal.curve25519
--
local fp = require "ccryptolib.internal.fp" local fp = require "ccryptolib.internal.fp"
local ed = require "ccryptolib.internal.edwards25519" local ed = require "ccryptolib.internal.edwards25519"
local random = require "ccryptolib.random" local random = require "ccryptolib.random"
--- @class MtPoint A point class on Curve25519, in XZ coordinates.
--- @field [1] number[] The X coordinate.
--- @field [2] number[] The Z coordinate.
--- Doubles a point.
--- @param P1 MtPoint The point to double.
--- @return MtPoint P2 P1 + P1.
local function double(P1) local function double(P1)
local x1, z1 = P1[1], P1[2] local x1, z1 = P1[1], P1[2]
local a = fp.add(x1, z1) local a = fp.add(x1, z1)
@ -26,6 +23,11 @@ local function double(P1)
return {x3, z3} return {x3, z3}
end end
--- Computes differential addition on two points.
--- @param DP MtPoint P1 - P2.
--- @param P1 MtPoint The first point to add.
--- @param P2 MtPoint The second point to add.
--- @return MtPoint P3 P1 + P2.
local function dadd(DP, P1, P2) local function dadd(DP, P1, P2)
local dx, dz = DP[1], DP[2] local dx, dz = DP[1], DP[2]
local x1, z1 = P1[1], P1[2] local x1, z1 = P1[1], P1[2]
@ -42,13 +44,11 @@ local function dadd(DP, P1, P2)
end end
--- Performs a step on the Montgomery ladder. --- Performs a step on the Montgomery ladder.
-- --- @param DP MtPoint P1 - P2.
-- @param C A - B. --- @param P1 MtPoint The first point.
-- @param A The first point. --- @param P2 MtPoint The second point.
-- @param B The second point. --- @return MtPoint P3 2A
-- @return 2A --- @return MtPoint P4 A + B
-- @return A + B
--
local function step(DP, P1, P2) local function step(DP, P1, P2)
local dx, dz = DP[1], DP[2] local dx, dz = DP[1], DP[2]
local x1, z1 = P1[1], P1[2] local x1, z1 = P1[1], P1[2]
@ -85,50 +85,46 @@ local function ladder(DP, bits)
end end
--- Performs a scalar multiplication operation with multiplication by 8. --- Performs a scalar multiplication operation with multiplication by 8.
-- --- @param P MtPoint The base point.
-- @tparam point P The base point. --- @param bits number[] The scalar multiplier, in little-endian bits.
-- @tparam {number...} bits The scalar multiplier, in little-endian bits. --- @return MtPoint product The product, multiplied by 8.
-- @treturn point The product, multiplied by 8.
--
local function ladder8(P, bits) local function ladder8(P, bits)
-- Randomize. -- Randomize.
local rf = fp.decode(random.random(32)) local rf = fp.decode(random.random(32) --[[@as String32, length is given]])
P = {fp.mul(P[1], rf), fp.mul(P[2], rf)} P = {fp.mul(P[1], rf), fp.mul(P[2], rf)}
-- Multiply. -- Multiply.
return double(double(double(ladder(P, bits)))) return double(double(double(ladder(P, bits))))
end end
--- Scales a point's coordinates.
--- @param P MtPoint The input point.
--- @return MtPoint Q The same point P, but with Z = 1.
local function scale(P) local function scale(P)
return {fp.mul(P[1], fp.invert(P[2])), fp.num(1)} return {fp.mul(P[1], fp.invert(P[2])), fp.num(1)}
end end
--- Encodes a point. --- Encodes a scaled point.
-- --- @param P MtPoint The scaled point to encode.
-- @tparam point P1 The scaled point to encode. --- @return string encoded P, encoded into a 32-byte string.
-- @treturn string The 32-byte encoded point.
--
local function encode(P) local function encode(P)
return fp.encode(P[1]) return fp.encode(P[1])
end end
--- Decodes a point. --- Decodes a point.
-- --- @param str String32 A 32-byte encoded point.
-- @tparam string str A 32-byte encoded point. --- @return MtPoint pt The decoded point.
-- @treturn point The decoded point.
--
local function decode(str) local function decode(str)
return {fp.decode(str), fp.num(1)} return {fp.decode(str), fp.num(1)}
end end
--- Decodes an Edwards25519 encoded point into Curve25519, ignoring the sign. --- Decodes an Edwards25519 encoded point into Curve25519, ignoring the sign.
-- ---
-- There is a single exception: The identity point (0, 1), which gets mapped --- There is a single exception: The identity point (0, 1), which gets mapped
-- into the 2-torsion point (0, 0), which isn't the identity of Curve25519. --- into the 2-torsion point (0, 0), which isn't the identity of Curve25519.
-- ---
-- @tparam string str A 32-byte encoded Edwards25519 point. --- @param str String32 A 32-byte encoded Edwards25519 point.
-- @treturn point The decoded point, mapped into Curve25519. --- @return MtPoint pt The decoded point, mapped into Curve25519.
--
local function decodeEd(str) local function decodeEd(str)
local y = fp.decode(str) local y = fp.decode(str)
local n = fp.carry(fp.add(fp.num(1), y)) local n = fp.carry(fp.add(fp.num(1), y))
@ -141,10 +137,8 @@ local function decodeEd(str)
end end
--- Performs a scalar multiplication by the base point G. --- Performs a scalar multiplication by the base point G.
-- --- @param bits number[] The scalar multiplier, in little-endian bits.
-- @tparam {number...} bits The scalar multiplier, in little-endian bits. --- @return MtPoint product The product point.
-- @return The product point.
--
local function mulG(bits) local function mulG(bits)
-- Multiply by G on Edwards25519. -- Multiply by G on Edwards25519.
local P = ed.mulG(bits) local P = ed.mulG(bits)
@ -159,17 +153,17 @@ local function mulG(bits)
end end
--- Computes a twofold product from a ruleset. --- Computes a twofold product from a ruleset.
-- ---
-- @tparam point P The base point. --- Returns nil if any of the results would be equal to the identity.
-- @tparam {{number...}, {number...}} The ruleset generated by scalars m, n. ---
-- @treturn[1] point [8m]P --- @param P MtPoint The base point.
-- @treturn[1] point [8n]P --- @param ruleset __TYPE_TODO The ruleset generated by scalars m, n.
-- @treturn[1] point [8m]P - [8n]P --- @return MtPoint? A [8m]P.
-- @treturn[2] nil If any of the three results is equal to O. --- @return MtPoint? B [8n]P.
-- --- @return MtPoint? C [8m]P - [8n]P.
local function prac(P, ruleset) local function prac(P, ruleset)
-- Randomize. -- Randomize.
local rf = fp.decode(random.random(32)) local rf = fp.decode(random.random(32) --[[@as String32, length is given]])
local A = {fp.mul(P[1], rf), fp.mul(P[2], rf)} local A = {fp.mul(P[1], rf), fp.mul(P[2], rf)}
-- Start the base at [8]P. -- Start the base at [8]P.
@ -184,7 +178,7 @@ local function prac(P, ruleset)
-- Reject rulesets where m = n. -- Reject rulesets where m = n.
local rules = ruleset[2] local rules = ruleset[2]
if #rules == 0 then return nil end if #rules == 0 then return end
-- Evaluate the first rule. -- Evaluate the first rule.
-- Since e = d, this means A - B = C = O. Differential addition fails when -- Since e = d, this means A - B = C = O. Differential addition fails when

View file

@ -1,30 +1,31 @@
--- Point arithmetic on the Edwards25519 Edwards curve. --- Point arithmetic on the Edwards25519 Edwards curve.
--
-- :::note Internal Module
-- This module is meant for internal use within the library. Its API is unstable
-- and subject to change without major version bumps.
-- :::
--
-- <br />
--
-- @module[kind=internal] internal.edwards25519
--
local fp = require "ccryptolib.internal.fp" local fp = require "ccryptolib.internal.fp"
local unpack = unpack or table.unpack local unpack = unpack or table.unpack
--- @class EdPoint A point on Edwards25519, in extended coordinates.
--- @field [1] number[] The X coordinate.
--- @field [2] number[] The Y coordinate.
--- @field [3] number[] The Z coordinate.
--- @field [4] number[] The T coordinate.
--- @class NsPoint A point on Edwards25519, in Niels' coordinates.
--- @field [1] number[] Preprocessed Y + X.
--- @field [2] number[] Preprocessed Y - X.
--- @field [3] number[] Preprocessed 2Z.
--- @field [4] number[] Preprocessed 2DT.
local D = fp.mul(fp.num(-121665), fp.invert(fp.num(121666))) local D = fp.mul(fp.num(-121665), fp.invert(fp.num(121666)))
local K = fp.kmul(D, 2) local K = fp.kmul(D, 2)
--- @type EdPoint
local O = {fp.num(0), fp.num(1), fp.num(1), fp.num(0)} local O = {fp.num(0), fp.num(1), fp.num(1), fp.num(0)}
local G = nil local G = nil
--- Doubles a point. --- Doubles a point.
-- --- @param P1 EdPoint The point to double.
-- @tparam point P1 The point to double. --- @return EdPoint P2 P1 + P1.
-- @treturn point Twice P1.
--
local function double(P1) local function double(P1)
-- Unsoundness: fp.sub(g, e), and fp.sub(d, i) break fp.sub's contract since -- Unsoundness: fp.sub(g, e), and fp.sub(d, i) break fp.sub's contract since
-- it doesn't accept an fp2. Although not ideal, in practice this doesn't -- it doesn't accept an fp2. Although not ideal, in practice this doesn't
@ -48,14 +49,12 @@ local function double(P1)
end end
--- Adds two points. --- Adds two points.
-- --- @param P1 EdPoint The first summand point.
-- @tparam point P1 The first summand point. --- @param N2 NsPoint The second summand point.
-- @tparam niels N1 The second summand point, in Niels form. See @{niels}. --- @return EdPoint P3 P1 + P2, where N2 = niels(P2).
-- @treturn point The sum. local function add(P1, N2)
--
local function add(P1, N1)
local P1x, P1y, P1z, P1t = unpack(P1) local P1x, P1y, P1z, P1t = unpack(P1)
local N1p, N1m, N1z, N1t = unpack(N1) local N1p, N1m, N1z, N1t = unpack(N2)
local a = fp.sub(P1y, P1x) local a = fp.sub(P1y, P1x)
local b = fp.mul(a, N1m) local b = fp.mul(a, N1m)
local c = fp.add(P1y, P1x) local c = fp.add(P1y, P1x)
@ -73,9 +72,13 @@ local function add(P1, N1)
return {P3x, P3y, P3z, P3t} return {P3x, P3y, P3z, P3t}
end end
local function sub(P1, N1) --- Subtracts one point from another.
--- @param P1 EdPoint The first summand point.
--- @param N2 NsPoint The second summand point.
--- @return EdPoint P3 P1 - P2, where N2 = niels(P2).
local function sub(P1, N2)
local P1x, P1y, P1z, P1t = unpack(P1) local P1x, P1y, P1z, P1t = unpack(P1)
local N1p, N1m, N1z, N1t = unpack(N1) local N1p, N1m, N1z, N1t = unpack(N2)
local a = fp.sub(P1y, P1x) local a = fp.sub(P1y, P1x)
local b = fp.mul(a, N1p) local b = fp.mul(a, N1p)
local c = fp.add(P1y, P1x) local c = fp.add(P1y, P1x)
@ -94,10 +97,8 @@ local function sub(P1, N1)
end end
--- Computes the Niels representation of a point. --- Computes the Niels representation of a point.
-- --- @param P1 EdPoint The input point.
-- @tparam point P1 --- @return NsPoint N1 Niels' precomputation applied to P1.
-- @treturn niels P1's Niels representation.
--
local function niels(P1) local function niels(P1)
local P1x, P1y, P1z, P1t = unpack(P1) local P1x, P1y, P1z, P1t = unpack(P1)
local N3p = fp.add(P1y, P1x) local N3p = fp.add(P1y, P1x)
@ -107,6 +108,9 @@ local function niels(P1)
return {N3p, N3m, N3z, N3t} return {N3p, N3m, N3z, N3t}
end end
--- Scales a point.
--- @param P1 EdPoint The input point.
--- @return EdPoint P2 The same point as P1, but with Z = 1.
local function scale(P1) local function scale(P1)
local P1x, P1y, P1z = unpack(P1) local P1x, P1y, P1z = unpack(P1)
local zInv = fp.invert(P1z) local zInv = fp.invert(P1z)
@ -117,11 +121,9 @@ local function scale(P1)
return {P3x, P3y, P3z, P3t} return {P3x, P3y, P3z, P3t}
end end
--- Encodes a point. --- Encodes a scaled point.
-- --- @param P1 EdPoint The scaled point to encode.
-- @tparam point P1 The scaled point to encode. --- @return string out P1 encoded as a 32-byte string.
-- @treturn string The 32-byte encoded point.
--
local function encode(P1) local function encode(P1)
P1 = scale(P1) P1 = scale(P1)
local P1x, P1y = unpack(P1) local P1x, P1y = unpack(P1)
@ -131,11 +133,8 @@ local function encode(P1)
end end
--- Decodes a point. --- Decodes a point.
-- --- @param str String32 A 32-byte encoded point.
-- @tparam string str A 32-byte encoded point. --- @return EdPoint? P1 The decoded point, or nil if it isn't on the curve.
-- @treturn[1] point The decoded point.
-- @treturn[2] nil If the string did not represent a valid encoded point.
--
local function decode(str) local function decode(str)
local P3y = fp.decode(str) local P3y = fp.decode(str)
local a = fp.square(P3y) local a = fp.square(P3y)
@ -153,8 +152,12 @@ local function decode(str)
return {P3x, P3y, P3z, P3t} return {P3x, P3y, P3z, P3t}
end end
G = decode("Xfffffffffffffffffffffffffffffff") G = decode("Xfffffffffffffffffffffffffffffff") --[[@as EdPoint, G is valid]]
--- Transforms little-endian bits into a signed radix-2^w form.
--- @param bits number[]
--- @param w number Log2 of the radix, must be at least 1.
--- @return number[]
local function signedRadixW(bits, w) local function signedRadixW(bits, w)
-- TODO Find a more elegant way of doing this. -- TODO Find a more elegant way of doing this.
local wPow = 2 ^ w local wPow = 2 ^ w
@ -176,6 +179,10 @@ local function signedRadixW(bits, w)
return out return out
end end
--- Computes a multiplication table for radix-2^w form multiplication.
--- @param P EdPoint The base point.
--- @param w number Log2 of the radix, must be at least 1.
--- @return NsPoint[][]
local function radixWTable(P, w) local function radixWTable(P, w)
local out = {} local out = {}
for i = 1, math.ceil(256 / w) do for i = 1, math.ceil(256 / w) do
@ -190,10 +197,21 @@ local function radixWTable(P, w)
return out return out
end end
--- The radix logarithm of the precomputed table for G.
local G_W = 5 local G_W = 5
--- The precomputed multiplication table for G.
local G_TABLE = radixWTable(G, G_W) local G_TABLE = radixWTable(G, G_W)
local function WNAF(bits, w) --- Transforms little-endian bits into a signed radix-2^w non-adjacent form.
---
--- The returned array contains a 0 whenever a single doubling is needed, or an
--- odd integer when an addition with a multiple of the base is needed.
---
--- @param bits number[]
--- @param w number Log2 of the radix, must be at least 1.
--- @return number[]
local function wNaf(bits, w)
-- TODO Find a more elegant way of doing this. -- TODO Find a more elegant way of doing this.
local wPow = 2 ^ w local wPow = 2 ^ w
local wPowh = wPow / 2 local wPowh = wPow / 2
@ -220,6 +238,10 @@ local function WNAF(bits, w)
return out return out
end end
--- Computes a multiplication table for wNAF form multiplication.
--- @param P EdPoint The base point.
--- @param w number Log2 of the radix, must be at least 1.
--- @return NsPoint[]
local function WNAFTable(P, w) local function WNAFTable(P, w)
local dP = double(P) local dP = double(P)
local out = {niels(P)} local out = {niels(P)}
@ -230,10 +252,8 @@ local function WNAFTable(P, w)
end end
--- Performs a scalar multiplication by the base point G. --- Performs a scalar multiplication by the base point G.
-- --- @param bits number[] The scalar multiplicand little-endian bits.
-- @tparam {number...} bits The scalar multiplier, in little-endian bits. --- @return EdPoint
-- @treturn point The product.
--
local function mulG(bits) local function mulG(bits)
local sw = signedRadixW(bits, G_W) local sw = signedRadixW(bits, G_W)
local R = O local R = O
@ -249,13 +269,11 @@ local function mulG(bits)
end end
--- Performs a scalar multiplication operation. --- Performs a scalar multiplication operation.
-- --- @param P EdPoint The base point.
-- @tparam point P The base point. --- @param bits number[] The scalar multiplicand little-endian bits.
-- @tparam {number...} bits The scalar multiplier, in little-endian bits. --- @return EdPoint
-- @treturn point The product.
--
local function mul(P, bits) local function mul(P, bits)
local naf = WNAF(bits, 5) local naf = wNaf(bits, 5)
local tbl = WNAFTable(P, 5) local tbl = WNAFTable(P, 5)
local R = O local R = O
for i = #naf, 1, -1 do for i = #naf, 1, -1 do

View file

@ -1,21 +1,19 @@
--- Arithmetic on Curve25519's base field. --- Arithmetic on Curve25519's base field.
--
-- :::note Internal Module
-- This module is meant for internal use within the library. Its API is unstable
-- and subject to change without major version bumps.
-- :::
--
-- <br />
--
-- @module[kind=internal] internal.fp
--
local packing = require "ccryptolib.internal.packing" local packing = require "ccryptolib.internal.packing"
local unpack = unpack or table.unpack local unpack = unpack or table.unpack
local ufp, fmtfp = packing.compileUnpack("<I3I3I2I3I3I2I3I3I2I3I3I2") local ufp, fmtfp = packing.compileUnpack("<I3I3I2I3I3I2I3I3I2I3I3I2")
--- @class Fq An element of the field of integers modulo 2²⁵⁵ - 19.
--- @class FpR2: Fq An Fp element with limbs inside twice the standard range.
--- @class FpR1: FpR2 An Fp element with limbs inside the standard range. See
--- the Curve25519 polynomial representation for more info around this.
--- The modular square root of -1. --- The modular square root of -1.
--- @type FpR1
local I = { local I = {
0958640 * 2 ^ 0, 0958640 * 2 ^ 0,
0826664 * 2 ^ 22, 0826664 * 2 ^ 22,
@ -32,19 +30,15 @@ local I = {
} }
--- Converts a Lua number to an element. --- Converts a Lua number to an element.
-- --- @param n number A number n in [0..2²²).
-- @tparam number n A number n in [0..2²²). --- @return FpR1 out The number as an element.
-- @treturn fp1
--
local function num(n) local function num(n)
return {n, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} return {n, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
end end
--- Negates an element. --- Negates an element.
-- --- @param a FpR1
-- @tparam fp1 a --- @return FpR1 out -a.
-- @treturn fp1 -a.
--
local function neg(a) local function neg(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
return { return {
@ -64,11 +58,9 @@ local function neg(a)
end end
--- Adds two elements. --- Adds two elements.
-- --- @param a FpR1
-- @tparam fp1 a --- @param b FpR1
-- @tparam fp1 b --- @return FpR2 out a + b.
-- @treturn fp2 a + b.
--
local function add(a, b) local function add(a, b)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b) local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b)
@ -89,11 +81,9 @@ local function add(a, b)
end end
--- Subtracts an element from another. --- Subtracts an element from another.
-- --- @param a FpR1
-- @tparam fp1 a --- @param b FpR1
-- @tparam fp1 b --- @return FpR2 out a - b.
-- @treturn fp2 a - b.
--
local function sub(a, b) local function sub(a, b)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b) local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b)
@ -113,13 +103,9 @@ local function sub(a, b)
} }
end end
--- Carries an element. --- Carries an element. Also performs a small reduction modulo p.
-- --- @param a FpR2 The element to carry.
-- Also performs a small reduction modulo p. --- @return FpR1 out The same element as a but in a tighter range.
--
-- @tparam fp2 a
-- @treturn fp1 a' ≡ a (mod p).
--
local function carry(a) local function carry(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
local c00, c01, c02, c03, c04, c05, c06, c07, c08, c09, c10, c11 local c00, c01, c02, c03, c04, c05, c06, c07, c08, c09, c10, c11
@ -157,14 +143,14 @@ local function carry(a)
end end
--- Returns the canoncal representative of a modp number. --- Returns the canoncal representative of a modp number.
-- ---
-- Some elements can be represented by two different arrays of floats. This --- Some elements can be represented by two different arrays of floats. This
-- returns the canonical element of the represented equivalence class. We define --- returns the canonical element of the represented equivalence class. We
-- an element as canonical if it's the smallest nonnegative number in its class. --- define an element as canonical if it's the smallest nonnegative number in
-- --- its class.
-- @tparam fp2 a ---
-- @treturn fp1 A canonical element a' ≡ a (mod p). --- @param a FpR2
-- --- @return FpR1 out A canonical element a' ≡ a (mod p).
local function canonicalize(a) local function canonicalize(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
local c00, c01, c02, c03, c04, c05, c06, c07, c08, c09, c10, c11 local c00, c01, c02, c03, c04, c05, c06, c07, c08, c09, c10, c11
@ -205,11 +191,9 @@ local function canonicalize(a)
end end
--- Returns whether two elements are the same. --- Returns whether two elements are the same.
-- --- @param a FpR1
-- @tparam fp1 a --- @param b FpR1
-- @tparam fp1 b --- @return boolean eq Whether a ≡ b (mod p).
-- @treturn boolean Whether the two elements are the same mod p.
--
local function eq(a, b) local function eq(a, b)
local c = canonicalize(sub(a, b)) local c = canonicalize(sub(a, b))
for i = 1, 12 do if c[i] ~= 0 then return false end end for i = 1, 12 do if c[i] ~= 0 then return false end end
@ -217,11 +201,9 @@ local function eq(a, b)
end end
--- Multiplies two elements. --- Multiplies two elements.
-- --- @param a FpR2
-- @tparam fp2 a --- @param b FpR2
-- @tparam fp2 b --- @return FpR1 c An element such that c ≡ a × b (mod p).
-- @treturn fp1 c ≡ a × b (mod p).
--
local function mul(a, b) local function mul(a, b)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b) local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11 = unpack(b)
@ -421,10 +403,8 @@ local function mul(a, b)
end end
--- Squares an element. --- Squares an element.
-- --- @param a FpR2
-- @tparam fp2 a --- @return FpR1 b An element such that b ≡ a² (mod p).
-- @treturn fp1 c ≡ a² (mod p).
--
local function square(a) local function square(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
local d00, d01, d02, d03, d04, d05, d06, d07, d08, d09, d10 local d00, d01, d02, d03, d04, d05, d06, d07, d08, d09, d10
@ -571,11 +551,9 @@ local function square(a)
end end
--- Multiplies an element by a number. --- Multiplies an element by a number.
-- --- @param a FpR2
-- @tparam fp2 a --- @param k number A number in [0..2²²).
-- @tparam number k A number k in [0..2²²). --- @return FpR1 c An element such that c ≡ a × k (mod p).
-- @treturn fp1 c ≡ a × k (mod p).
--
local function kmul(a, k) local function kmul(a, k)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
local c00, c01, c02, c03, c04, c05, c06, c07, c08, c09, c10, c11 local c00, c01, c02, c03, c04, c05, c06, c07, c08, c09, c10, c11
@ -627,24 +605,20 @@ local function kmul(a, k)
end end
--- Squares an element n times. --- Squares an element n times.
-- --- @param a FpR2
-- @tparam fp2 a --- @param n number The number of times to square a.
-- @tparam number n A positive integer. --- @return FpR1 c A number c such that c ≡ a ^ 2 ^ n (mod p).
-- @treturn fp1 c ≡ a ^ 2 ^ n (mod p).
--
local function nsquare(a, n) local function nsquare(a, n)
for _ = 1, n do a = square(a) end for _ = 1, n do a = square(a) end
return a return a
end end
--- Computes the inverse of an element. --- Computes the inverse of an element.
-- ---
-- Computation of the inverse requires 11 multiplications and 252 squarings. --- Performance: 11 multiplications and 252 squarings.
-- ---
-- @tparam fp2 a --- @param a FpR2
-- @treturn[1] fp1 c ≡ a⁻¹ (mod p), if a ≠ 0. --- @return FpR1 c An element such that c ≡ a⁻¹ (mod p), or 0 if c doesn't exist.
-- @treturn[2] fp1 c ≡ 0 (mod p), if a = 0.
--
local function invert(a) local function invert(a)
local a2 = square(a) local a2 = square(a)
local a9 = mul(a, nsquare(a2, 2)) local a9 = mul(a, nsquare(a2, 2))
@ -662,15 +636,13 @@ local function invert(a)
return mul(nsquare(x250, 5), a11) return mul(nsquare(x250, 5), a11)
end end
--- Returns an element x that satisfies v * x² = u. --- Returns an element x that satisfies vx² = u.
-- ---
-- Note that when v = 0, the returned element can take any value. --- Note that when v = 0, the returned element can take any value.
-- ---
-- @tparam fp2 u --- @param u FpR2
-- @tparam fp2 v --- @param v FpR2
-- @treturn[1] fp1 x. --- @return FpR1? x An element such that vx² ≡ u (mod p), if it exists.
-- @treturn[2] nil if there is no solution.
--
local function sqrtDiv(u, v) local function sqrtDiv(u, v)
u = carry(u) u = carry(u)
@ -711,11 +683,11 @@ local function sqrtDiv(u, v)
end end
end end
--- @class String32: string A string with length equal to 32 bytes.
--- Encodes an element in little-endian. --- Encodes an element in little-endian.
-- --- @param a FpR1
-- @tparam fp2 a --- @return String32 out The 32-byte canonical encoding of a.
-- @treturn string A 32-byte string. Always represents the canonical element.
--
local function encode(a) local function encode(a)
a = canonicalize(a) a = canonicalize(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10, a11 = unpack(a)
@ -744,14 +716,12 @@ local function encode(a)
putBytes(3) acc = acc + a11 / 2 ^ 232 putBytes(3) acc = acc + a11 / 2 ^ 232
putBytes(3) putBytes(3)
return string.char(unpack(bytes)) return string.char(unpack(bytes)) --[[@as String32, putBytes sums to 32]]
end end
--- Decodes an element in little-endian. --- Decodes an element in little-endian.
-- --- @param b String32 A 32-byte string, the most-significant bit is discarded.
-- @tparam string b A 32-byte string. The most-significant bit is discarded. --- @return FpR1 out The decoded element. It may not be canonical.
-- @treturn fp1 The decoded element. May not be canonical.
--
local function decode(b) local function decode(b)
local w00, w01, w02, w03, w04, w05, w06, w07, w08, w09, w10, w11 = local w00, w01, w02, w03, w04, w05, w06, w07, w08, w09, w10, w11 =
ufp(fmtfp, b, 1) ufp(fmtfp, b, 1)
@ -774,11 +744,9 @@ local function decode(b)
} }
end end
--- Checks if two elements are equal. --- Checks if the given element is equal to 0.
-- --- @param a FpR2
-- @tparam fp2 a --- @return boolean eqz Whether a ≡ 0 (mod p).
-- @treturn boolean Whether a ≡ 0 (mod p).
--
local function eqz(a) local function eqz(a)
local c = canonicalize(a) local c = canonicalize(a)
local c00, c01, c02, c03, c04, c05, c06, c07, c08, c09, c10, c11 = unpack(c) local c00, c01, c02, c03, c04, c05, c06, c07, c08, c09, c10, c11 = unpack(c)

View file

@ -10,8 +10,8 @@
-- @module[kind=internal] internal.fq -- @module[kind=internal] internal.fq
-- --
local mp = require "ccryptolib.internal.mp" local mp = require "ccryptolib.internal.mp"
local util = require "ccryptolib.internal.util" local util = require "ccryptolib.internal.util"
local packing = require "ccryptolib.internal.packing" local packing = require "ccryptolib.internal.packing"
local unpack = unpack or table.unpack local unpack = unpack or table.unpack
@ -90,14 +90,12 @@ local function reduce(a)
local c = mp.sub(a, Q) local c = mp.sub(a, Q)
-- Return carry(a) if a < q. -- Return carry(a) if a < q.
if mp.approx(c) < 0 then return mp.carry(a) end if mp.approx(c) < 0 then return (mp.carry(a)) end
-- c >= q means c - q >= 0. -- c >= q means c - q >= 0.
-- Since q < 2²⁸⁸, c < 2q means c - q < q < 2²⁸⁸. -- Since q < 2²⁸⁸, c < 2q means c - q < q < 2²⁸⁸.
-- c's limbs fit in (-2²⁶..2²⁶), since subtraction adds at most one bit. -- c's limbs fit in (-2²⁶..2²⁶), since subtraction adds at most one bit.
local cc = mp.carry(c) return (mp.carry(c)) -- cc < q implies that the carry number is 0.
cc[12] = nil -- cc < q implies that cc[12] = 0.
return cc
end end
--- Adds two scalars mod q. --- Adds two scalars mod q.
@ -170,15 +168,6 @@ local function demontgomery(a)
return reduce(s1) return reduce(s1)
end end
--- Converts a Lua number to a scalar.
--
-- @tparam number n A number n in [0..2²⁴).
-- @treturn {number...} 2²⁶⁴ × n mod q as 11 limbs in [0..2²⁴).
--
local function num(n)
return montgomery({n, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
end
--- Encodes a scalar. --- Encodes a scalar.
-- --
-- @tparam {number...} a A number 2²⁶⁴ × a mod q as 11 limbs in [0..2²⁴). -- @tparam {number...} a A number 2²⁶⁴ × a mod q as 11 limbs in [0..2²⁴).
@ -378,12 +367,8 @@ local function makeRuleset(a, b)
end end
return { return {
num = num,
add = add, add = add,
neg = neg,
sub = sub, sub = sub,
montgomery = montgomery,
demontgomery = demontgomery,
mul = mul, mul = mul,
encode = encode, encode = encode,
decode = decode, decode = decode,

View file

@ -12,12 +12,19 @@
local unpack = unpack or table.unpack local unpack = unpack or table.unpack
--- A little-endian big integer of width 11 in (-2⁵²..2⁵²).
--- @class MpSW11L52
--- A little-endian big integer of width 11 in (-2²⁴, 2²⁴).
--- @class MpSW11L24: MpSW11L52
--- A little-endian big integer of width 11 in [0..2²⁴).
--- @class MpUW11L24: MpSW11L24
--- Carries a number in base 2²⁴ into a signed limb form. --- Carries a number in base 2²⁴ into a signed limb form.
-- --- @param a MpSW11L52
-- @tparam {number...} a A number a in (-2²⁸⁸..2²⁸⁸) as 11 limbs in --- @return MpSW11L24 low The carried low limbs.
-- [-2⁵²..2⁵²]. --- @return number carry The overflowed carry.
-- @treturn {number...} a as 12 limbs in (-2²⁴..2²⁴).
--
local function carryWeak(a) local function carryWeak(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a)
@ -45,16 +52,13 @@ local function carryWeak(a)
a08 - h08, a08 - h08,
a09 - h09, a09 - h09,
a10 - h10, a10 - h10,
h10 * 2 ^ -24, }, h10 * 2 ^ -24
}
end end
--- Carries a number in base 2²⁴. --- Carries a number in base 2²⁴.
-- --- @param a MpSW11L52
-- @tparam {number...} a A number a in [0..2²⁸⁸) as 11 limbs in --- @return MpUW11L24 low The low 11 limbs of the output.
-- [-2⁵²..2⁵²]. --- @return number carry The overflow carry.
-- @treturn {number...} a as 12 limbs in [0..2²⁴).
--
local function carry(a) local function carry(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a)
@ -71,15 +75,13 @@ local function carry(a)
local l10 = a10 % 2 ^ 24 local l10 = a10 % 2 ^ 24
local h10 = (a10 - l10) * 2 ^ -24 local h10 = (a10 - l10) * 2 ^ -24
return {l00, l01, l02, l03, l04, l05, l06, l07, l08, l09, l10, h10} return {l00, l01, l02, l03, l04, l05, l06, l07, l08, l09, l10}, h10
end end
--- Adds two numbers. --- Adds two numbers.
-- --- @param a MpSW11L24
-- @tparam {number...} a An array of 11 limbs in (k..l). --- @param b MpSW11L24
-- @tparam {number...} b An array of 11 limbs in (m..n). --- @return MpSW11L52 c a + b
-- @treturn {number...} a + b as 11 limbs in ((k + m)..(l + n)).
--
local function add(a, b) local function add(a, b)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a)
local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10 = unpack(b) local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10 = unpack(b)
@ -100,11 +102,9 @@ local function add(a, b)
end end
--- Subtracts a number from another. --- Subtracts a number from another.
-- --- @param a MpSW11L24
-- @tparam {number...} a An array of 11 limbs in (k..l). --- @param b MpSW11L24
-- @tparam {number...} b An array of 11 limbs in (m..n). --- @return MpSW11L52 c a - b
-- @treturn {number...} a + b as 11 limbs in ((k - m)..(l - n)).
--
local function sub(a, b) local function sub(a, b)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a)
local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10 = unpack(b) local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10 = unpack(b)
@ -125,17 +125,15 @@ local function sub(a, b)
end end
--- Computes the lower half of a product between two numbers. --- Computes the lower half of a product between two numbers.
-- --- @param a MpUW11L24
-- @tparam {number...} a A nonnegative integer as 11 limbs in [0..2²⁴). --- @param b MpUW11L24
-- @tparam {number...} b A nonnegative integer as 11 limbs in [0..2²⁴). --- @return MpUW11L24 c a × b (mod 2²⁶⁴)
-- @treturn {number...} c ≡ a × b (mod 2²⁶⁴) as 11 limbs in [0..2²⁴). --- @return number carry ⌊a × b ÷ 2²⁶⁴⌋
-- @treturn number ⌊a × b ÷ 2²⁶⁴⌋.
--
local function lmul(a, b) local function lmul(a, b)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a)
local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10 = unpack(b) local b00, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10 = unpack(b)
local out = carry { return carry {
a00 * b00, a00 * b00,
a01 * b00 + a00 * b01, a01 * b00 + a00 * b01,
a02 * b00 + a01 * b01 + a00 * b02, a02 * b00 + a01 * b01 + a00 * b02,
@ -148,28 +146,21 @@ local function lmul(a, b)
a09 * b00 + a08 * b01 + a07 * b02 + a06 * b03 + a05 * b04 + a04 * b05 + a03 * b06 + a02 * b07 + a01 * b08 + a00 * b09, a09 * b00 + a08 * b01 + a07 * b02 + a06 * b03 + a05 * b04 + a04 * b05 + a03 * b06 + a02 * b07 + a01 * b08 + a00 * b09,
a10 * b00 + a09 * b01 + a08 * b02 + a07 * b03 + a06 * b04 + a05 * b05 + a04 * b06 + a03 * b07 + a02 * b08 + a01 * b09 + a00 * b10, a10 * b00 + a09 * b01 + a08 * b02 + a07 * b03 + a06 * b04 + a05 * b05 + a04 * b06 + a03 * b07 + a02 * b08 + a01 * b09 + a00 * b10,
} }
-- Strip overflow.
local of = out[12]
out[12] = nil
return out, of
end end
--- Computes the a product between two numbers. --- Computes the a product between two numbers.
-- --- @param a MpUW11L24
-- @tparam {number...} a An array of 11 limbs in [0..2²⁴). --- @param b MpUW11L24
-- @tparam {number...} b An array of 11 limbs in [0..2²⁴). --- @return MpUW11L24 low The low 11 limbs of a × b.
-- @treturn {number...} The first 11 limbs of a × b in [0..2²⁴). --- @return MpUW11L24 high The high 11 limbs of a × b.
-- @treturn {number...} The last 11 limbs of a × b in [0..2²⁴).
--
local function mul(a, b) local function mul(a, b)
local low, of = lmul(a, b) local low, of = lmul(a, b)
local _, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a) local _, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a)
local _, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10 = unpack(b) local _, b01, b02, b03, b04, b05, b06, b07, b08, b09, b10 = unpack(b)
local high = carry { -- The carry is always 0.
return low, (carry {
of + a10 * b01 + a09 * b02 + a08 * b03 + a07 * b04 + a06 * b05 + a05 * b06 + a04 * b07 + a03 * b08 + a02 * b09 + a01 * b10, of + a10 * b01 + a09 * b02 + a08 * b03 + a07 * b04 + a06 * b05 + a05 * b06 + a04 * b07 + a03 * b08 + a02 * b09 + a01 * b10,
a10 * b02 + a09 * b03 + a08 * b04 + a07 * b05 + a06 * b06 + a05 * b07 + a04 * b08 + a03 * b09 + a02 * b10, a10 * b02 + a09 * b03 + a08 * b04 + a07 * b05 + a06 * b06 + a05 * b07 + a04 * b08 + a03 * b09 + a02 * b10,
a10 * b03 + a09 * b04 + a08 * b05 + a07 * b06 + a06 * b07 + a05 * b08 + a04 * b09 + a03 * b10, a10 * b03 + a09 * b04 + a08 * b05 + a07 * b06 + a06 * b07 + a05 * b08 + a04 * b09 + a03 * b10,
@ -181,40 +172,31 @@ local function mul(a, b)
a10 * b09 + a09 * b10, a10 * b09 + a09 * b10,
a10 * b10, a10 * b10,
0 0
} })
-- Strip overflow (it's always 0).
high[12] = nil
return low, high
end end
--- Computes a double-width sum of two numbers. --- Computes a double-width sum of two numbers.
-- --- @param a0 MpUW11L24 The low 11 limbs of a.
-- @tparam {number...} a0 The low part of a as 11 limbs in [0..2²⁴). --- @param a1 MpUW11L24 The high 11 limbs of a.
-- @tparam {number...} a1 The high part of a as 11 limbs in [0..2²⁴). --- @param b0 MpUW11L24 The low 11 limbs of b.
-- @tparam {number...} b0 The low part of b as 11 limbs in [0..2²⁴). --- @param b1 MpUW11L24 The high 11 limbs of b.
-- @tparam {number...} b1 The high part of b as 11 limbs in [0..2²⁴). --- @return MpUW11L24 c0 The low 11 limbs of a + b.
-- @treturn {number...} The low part of a + b as 11 limbs in [0..2²⁴). --- @return MpUW11L24 c1 The high 11 limbs of a + b.
-- @treturn {number...} The high part of a + b as 12 limbs in [0..2²⁴). --- @return number The carry.
--
local function dwadd(a0, a1, b0, b1) local function dwadd(a0, a1, b0, b1)
local low = carry(add(a0, b0)) local low, c = carry(add(a0, b0))
local high = add(a1, b1) local high = add(a1, b1)
high[1] = high[1] + low[12] high[1] = high[1] + c
low[12] = nil
return low, carry(high) return low, carry(high)
end end
--- Computes half of a number. --- Computes half of a number.
-- --- @param a MpSW11L24 The number to halve, must be even.
-- @tparam {number...} a An even positive integer as 11 limbs in (-2²⁴..2²⁴). --- @return MpSW11L24 c a ÷ 2
-- @treturn {number...} a ÷ 2 as 11 limbs in (-2²⁴..2²⁴).
--
local function half(a) local function half(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a)
local out = carryWeak { return (carryWeak {
a00 * 0.5 + a01 * 2 ^ 23, a00 * 0.5 + a01 * 2 ^ 23,
a02 * 2 ^ 23, a02 * 2 ^ 23,
a03 * 2 ^ 23, a03 * 2 ^ 23,
@ -226,18 +208,12 @@ local function half(a)
a09 * 2 ^ 23, a09 * 2 ^ 23,
a10 * 2 ^ 23, a10 * 2 ^ 23,
0, 0,
} })
out[12] = nil
return out
end end
--- Computes a third of a number. --- Computes a third of a number.
-- --- @param a MpSW11L24 The number to divide, must be a multiple of 3.
-- @tparam {number...} a A positive multiple of 3 as 11 limbs in (-2²⁶..2²⁶). --- @return MpSW11L24 c a ÷ 3
-- @treturn {number...} a ÷ 3 as 11 limbs in (-2²⁴..2²⁴).
--
local function third(a) local function third(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a)
@ -253,7 +229,9 @@ local function third(a)
local d09 = a09 * 0xaaaaaa + d08 local d09 = a09 * 0xaaaaaa + d08
local d10 = a10 * 0xaaaaaa + d09 local d10 = a10 * 0xaaaaaa + d09
local out = carryWeak { -- We compute the modular division mod 2²⁶⁴. The carry isn't 0 but it isn't
-- part of a ÷ 3 either.
return (carryWeak {
a00 + d00, a00 + d00,
a01 + d01, a01 + d01,
a02 + d02, a02 + d02,
@ -265,39 +243,27 @@ local function third(a)
a08 + d08, a08 + d08,
a09 + d09, a09 + d09,
a10 + d10, a10 + d10,
} })
-- We compute the modular division mod 2²⁶⁴. out[12] isn't 0 but it's not
-- part of a ÷ 3 either.
out[12] = nil
return out
end end
--- Computes a number modulo 2. --- Computes a number modulo 2.
-- --- @param a MpSW11L24
-- @tparam {number...} a A number as 11 limbs in (-2²⁶, 2²⁶). --- @return number c a mod 2.
-- @treturn number a mod 2.
--
local function mod2(a) local function mod2(a)
return a[1] % 2 return a[1] % 2
end end
--- Computes a number modulo 3. --- Computes a number modulo 3.
-- --- @param a MpSW11L24
-- @tparam {number...} a A number as 11 limbs in (-2²⁶, 2²⁶). --- @return number c a mod 3.
-- @treturn number a mod 3.
--
local function mod3(a) local function mod3(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a)
return (a00 + a01 + a02 + a03 + a04 + a05 + a06 + a07 + a08 + a09 + a10) % 3 return (a00 + a01 + a02 + a03 + a04 + a05 + a06 + a07 + a08 + a09 + a10) % 3
end end
--- Computes a double representing the most-significant bits of a number. --- Computes a double representing the most-significant bits of a number.
-- --- @param a MpSW11L52
-- @tparam {number...} a A number as 11 limbs in (-2⁴⁸..2⁴⁸). --- @return number c A floating-point approximation for the value of a.
-- @treturn number A floating-point approximation for the value of a.
--
local function approx(a) local function approx(a)
local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a) local a00, a01, a02, a03, a04, a05, a06, a07, a08, a09, a10 = unpack(a)
return a00 return a00
@ -314,11 +280,9 @@ local function approx(a)
end end
--- Compares two numbers for ordering. --- Compares two numbers for ordering.
-- --- @param a MpSW11L24
-- @tparam {number...} a A number as 11 limbs in (-2²⁵..2²⁵). --- @param b MpSW11L24
-- @tparam {number...} b A number as 11 limbs in (-2²⁵..2²⁵). --- @return number ord Some number with ord < 0 iff a < b and ord = 0 iff a = b.
-- @treturn number Some number x with x < 0 iff a < b and x = 0 iff a = b.
--
local function cmp(a, b) local function cmp(a, b)
return approx(sub(a, b)) return approx(sub(a, b))
end end

View file

@ -1,22 +1,11 @@
--- High-performance binary packing of integers. --- High-performance binary packing of integers.
-- ---
-- :::note Internal Module --- Remark (and warning):
-- This module is meant for internal use within the library. Its API is unstable --- For performance reasons, **the generated functions do not check types,
-- and subject to change without major version bumps. --- lengths, nor ranges**. You must ensure that the passed arguments are
-- ::: --- well-formed and respect the format string yourself.
-- ---
-- <br /> --- <br />
--
-- :::warning
-- For performance reasons, **the generated functions do not check types,
-- lengths, nor ranges**. You must ensure that the passed arguments are
-- well-formed and respect the format string yourself.
-- :::
--
-- <br />
--
-- @module[kind=internal] internal.packing
--
local fmt = string.format local fmt = string.format
@ -119,14 +108,17 @@ if not string.pack or pcall(string.dump, string.pack) then
local packCache = {} local packCache = {}
local unpackCache = {} local unpackCache = {}
--- (`string.pack == nil`) Compiles a binary packing function. -- I CAN'T EVEN WITH THIS EXTENSION, WHY CAN'T IT HANDLE MORE THAN A SINGLE
-- @tparam string fmt A string matched by `^([><])I[I%d]+$`. -- LINE OF RETURN DESCRIPTION? LOOK AT IT!!! THE COMMENT GOES OVER THERE ------------------------------------------------------------------> look! ↓ ↓ ↓
-- @treturn function A high-performance function that behaves like an unsafe
-- version of `string.pack` for the given format string. Note that the third --- (string.pack is nil) Compiles a binary packing function.
-- argument isn't optional. ---
-- @treturn string fmt --- Errors if the format string is invalid or has an invalid integral size,
-- @throws If the string is invalid or has an invalid integral size. --- or if the compiled function turns out too large.
-- @throws If the compiled function is too large. ---
--- @param fmt string A string matched by `^([><])I[I%d]+$`.
--- @return fun(_ignored: any, ...: any): string pack A function that behaves like an unsafe version of `string.pack` for the given format string.
--- @return string fmt
function mod.compilePack(fmt) function mod.compilePack(fmt)
if not packCache[fmt] then if not packCache[fmt] then
packCache[fmt] = compile(fmt, mkPack) packCache[fmt] = compile(fmt, mkPack)
@ -134,13 +126,14 @@ if not string.pack or pcall(string.dump, string.pack) then
return packCache[fmt], fmt return packCache[fmt], fmt
end end
--- (`string.pack == nil`) Compiles a binary unpacking function. --- (string.pack is nil) Compiles a binary unpacking function.
-- @tparam string fmt A string matched by `^([><])I[I%d]+$`. ---
-- @treturn function A high-performance function that behaves like an unsafe --- Errors if the format string is invalid or has an invalid integral size,
-- version of `string.unpack` for the given format string. --- or if the compiled function turns out too large.
-- @treturn string fmt ---
-- @throws If the string is invalid or has an invalid integral size. --- @param fmt string A string matched by `^([><])I[I%d]+$`.
-- @throws If the compiled function is too large. --- @return fun(_ignored: any, str: string, pos: number) unpack A function that behaves like an unsafe version of `string.unpack` for the given format string. Note that the third argument isn't optional.
--- @return string fmt
function mod.compileUnpack(fmt) function mod.compileUnpack(fmt)
if not unpackCache[fmt] then if not unpackCache[fmt] then
unpackCache[fmt] = compile(fmt, mkUnpack) unpackCache[fmt] = compile(fmt, mkUnpack)
@ -150,16 +143,16 @@ if not string.pack or pcall(string.dump, string.pack) then
return mod return mod
else else
--- (`string.pack ~= nil`) Compiles a binary packing function. --- (string.pack isn't nil) It's string.pack! It returns string.pack!
-- @tparam string fmt --- @param fmt string
-- @treturn function `string.pack` --- @return fun(fmt: string, ...: any): string pack string.pack!
-- @treturn string fmt --- @return string fmt
mod.compilePack = function(fmt) return string.pack, fmt end mod.compilePack = function(fmt) return string.pack, fmt end
--- (`string.pack ~= nil`) Compiles a binary unpacking function. --- (string.pack isn't nil) It's string.unpack! It returns string.unpack!
-- @tparam string fmt --- @param fmt string
-- @treturn function `string.unpack` --- @return fun(fmt: string, str: string, pos: number) unpack string.unpack!
-- @treturn string fmt --- @return string fmt
mod.compileUnpack = function(fmt) return string.unpack, fmt end mod.compileUnpack = function(fmt) return string.unpack, fmt end
end end

View file

@ -1,14 +1,4 @@
--- The SHA512 cryptographic hash function. --- The SHA512 cryptographic hash function.
--
-- :::note Internal Module
-- This module is meant for internal use within the library. Its API is unstable
-- and subject to change without major version bumps.
-- :::
--
-- <br />
--
-- @module[kind=internal] internal.sha512
--
local expect = require "cc.expect".expect local expect = require "cc.expect".expect
local packing = require "ccryptolib.internal.packing" local packing = require "ccryptolib.internal.packing"
@ -59,10 +49,8 @@ local K = {
} }
--- Hashes data bytes using SHA512. --- Hashes data bytes using SHA512.
-- --- @param data string The input data.
-- @tparam string data The input data. --- @return string hash The 64-byte hash value.
-- @treturn string The 64-byte hash value.
--
local function digest(data) local function digest(data)
expect(1, data, "string") expect(1, data, "string")

View file

@ -4,12 +4,10 @@ local function lassert(val, err, level)
end end
--- Converts a little-endian array from one power-of-two base to another. --- Converts a little-endian array from one power-of-two base to another.
-- --- @param a number[] The array to convert, in little-endian.
-- @tparam {number...} a The array to convert, in little-endian. --- @param base1 number The base to convert from. Must be a power of 2.
-- @tparam number base1 The base to convert from. Must be a power of 2. --- @param base2 number The base to convert to. Must be a power of 2.
-- @tparam number base2 The base to convert to. Must be a power of 2. --- @return number[]
-- @treturn {number...}
--
local function rebaseLE(a, base1, base2) -- TODO Write contract properly. local function rebaseLE(a, base1, base2) -- TODO Write contract properly.
local out = {} local out = {}
local outlen = 1 local outlen = 1
@ -33,10 +31,8 @@ local function rebaseLE(a, base1, base2) -- TODO Write contract properly.
end end
--- Decodes bits with X25519/Ed25519 exponent clamping. --- Decodes bits with X25519/Ed25519 exponent clamping.
-- --- @param str string The 32-byte encoded exponent.
-- @taparm string str The 32-byte encoded exponent. --- @return number[] bits The decoded clamped bits.
-- @treturn {number...} The decoded clamped bits.
--
local function bits(str) local function bits(str)
-- Decode. -- Decode.
local bytes = {str:byte(1, 32)} local bytes = {str:byte(1, 32)}
@ -61,10 +57,8 @@ local function bits(str)
end end
--- Decodes bits with X25519/Ed25519 exponent clamping and division by 8. --- Decodes bits with X25519/Ed25519 exponent clamping and division by 8.
-- --- @param str string The 32-byte encoded exponent.
-- @taparm string str The 32-byte encoded exponent. --- @return number[] bits The decoded clamped bits, divided by 8.
-- @treturn {number...} The decoded clamped bits, divided by 8.
--
local function bits8(str) local function bits8(str)
return {unpack(bits(str), 4)} return {unpack(bits(str), 4)}
end end

View file

@ -1,7 +1,4 @@
--- The Poly1305 one-time authenticator. --- The Poly1305 one-time authenticator.
--
-- @module poly1305
--
local expect = require "cc.expect".expect local expect = require "cc.expect".expect
local lassert = require "ccryptolib.internal.util".lassert local lassert = require "ccryptolib.internal.util".lassert
@ -13,11 +10,9 @@ local p4x4 = packing.compilePack(fmt4x4)
local mod = {} local mod = {}
--- Computes a Poly1305 message authentication code. --- Computes a Poly1305 message authentication code.
-- --- @param key string A 32-byte single-use random key.
-- @tparam string key A 32-byte single-use random key. --- @param message string The message to authenticate.
-- @tparam string message The message to authenticate. --- @return string tag The 16-byte authentication tag.
-- @treturn string The 16-byte authentication tag.
--
function mod.mac(key, message) function mod.mac(key, message)
expect(1, key, "string") expect(1, key, "string")
lassert(#key == 32, "key length must be 32", 2) lassert(#key == 32, "key length must be 32", 2)

View file

@ -22,9 +22,7 @@ local initialized = false
local mod = {} local mod = {}
--- Mixes entropy into the generator, and marks it as initialized. --- Mixes entropy into the generator, and marks it as initialized.
-- --- @param seed string The seed data.
-- @tparam string seed The seed data.
--
function mod.init(seed) function mod.init(seed)
expect(1, seed, "string") expect(1, seed, "string")
state = blake3.digestKeyed(state, seed) state = blake3.digestKeyed(state, seed)
@ -32,18 +30,14 @@ function mod.init(seed)
end end
--- Mixes extra entropy into the generator state. --- Mixes extra entropy into the generator state.
-- --- @param data string The additional entropy to mix.
-- @tparam string seed The additional entropy to mix.
--
function mod.mix(data) function mod.mix(data)
state = blake3.digestKeyed(state, data) state = blake3.digestKeyed(state, data)
end end
--- Generates random bytes. --- Generates random bytes.
-- --- @param len number The desired output length.
-- @tparam number len The desired output length. --- @return string bytes
-- @throws If the generator hasn't been initialized.
--
function mod.random(len) function mod.random(len)
expect(1, len, "number") expect(1, len, "number")
lassert(initialized, "attempt to use an uninitialized random generator", 2) lassert(initialized, "attempt to use an uninitialized random generator", 2)

View file

@ -1,7 +1,4 @@
--- The SHA256 cryptographic hash function. --- The SHA256 cryptographic hash function.
--
-- @module sha256
--
local expect = require "cc.expect".expect local expect = require "cc.expect".expect
local lassert = require "ccryptolib.internal.util".lassert local lassert = require "ccryptolib.internal.util".lassert
@ -79,10 +76,8 @@ local function compress(h, w)
end end
--- Hashes data using SHA256. --- Hashes data using SHA256.
-- --- @param data string Input bytes.
-- @tparam string data Input bytes. --- @return string hash The 32-byte hash value.
-- @treturn string The 32-byte hash value.
--
local function digest(data) local function digest(data)
expect(1, data, "string") expect(1, data, "string")
@ -101,12 +96,10 @@ local function digest(data)
end end
--- Hashes a password using PBKDF2-HMAC-SHA256. --- Hashes a password using PBKDF2-HMAC-SHA256.
-- --- @param password string The password to hash.
-- @tparam password string The password to hash. --- @param salt string The password's salt.
-- @tparam salt string The password's salt. --- @param iter number The number of iterations to perform.
-- @tparam iter number The number of iterations to perform. --- @return string dk The 32-byte derived key.
-- @treturn string The 32-byte derived key.
--
local function pbkdf2(password, salt, iter) local function pbkdf2(password, salt, iter)
expect(1, password, "string") expect(1, password, "string")
expect(2, salt, "string") expect(2, salt, "string")

View file

@ -1,20 +1,15 @@
--- The X25519 key exchange scheme. --- The X25519 key exchange scheme.
--
-- @module x25519
--
local expect = require "cc.expect".expect local expect = require "cc.expect".expect
local lassert = require "ccryptolib.internal.util".lassert local lassert = require "ccryptolib.internal.util".lassert
local util = require "ccryptolib.internal.util" local util = require "ccryptolib.internal.util"
local c25 = require "ccryptolib.internal.curve25519" local c25 = require "ccryptolib.internal.curve25519"
local mod = {} local mod = {}
--- Computes the public key from a secret key. --- Computes the public key from a secret key.
-- --- @param sk string A random 32-byte secret key.
-- @tparam string sk A random 32-byte secret key. --- @return string pk The matching public key.
-- @treturn string The matching public key.
--
function mod.publicKey(sk) function mod.publicKey(sk)
expect(1, sk, "string") expect(1, sk, "string")
assert(#sk == 32, "secret key length must be 32") assert(#sk == 32, "secret key length must be 32")
@ -22,25 +17,27 @@ function mod.publicKey(sk)
end end
--- Performs the key exchange. --- Performs the key exchange.
-- --- @param sk string A Curve25519 secret key.
-- @tparam string sk A secret key. --- @param pk string A public key, usually derived from someone else's secret key.
-- @tparam string pk A public key, usually derived from a second secret key. --- @return string ss The 32-byte shared secret between both keys.
-- @treturn string The 32-byte shared secret between both keys.
--
function mod.exchange(sk, pk) function mod.exchange(sk, pk)
expect(1, sk, "string") expect(1, sk, "string")
lassert(#sk == 32, "secret key length must be 32", 2) lassert(#sk == 32, "secret key length must be 32", 2)
expect(2, pk, "string") expect(2, pk, "string")
lassert(#pk == 32, "public key length must be 32", 2) lassert(#pk == 32, "public key length must be 32", 2) --- @cast pk String32
return c25.encode(c25.scale(c25.ladder8(c25.decode(pk), util.bits8(sk)))) return c25.encode(c25.scale(c25.ladder8(c25.decode(pk), util.bits8(sk))))
end end
--- Same as @{exchange}, but decodes the public key as an Edwards25519 point. --- Performs the key exchange, but decoding the public key as an Edwards25519
--- point, using the birational map.
--- @param sk string A Curve25519 secret key
--- @param pk string An Edwards25519 public key, usually derived from someone else's secret key.
--- @return string ss The 32-byte shared secret between both keys.
function mod.exchangeEd(sk, pk) function mod.exchangeEd(sk, pk)
expect(1, sk, "string") expect(1, sk, "string")
lassert(#sk == 32, "secret key length must be 32", 2) lassert(#sk == 32, "secret key length must be 32", 2)
expect(2, pk, "string") expect(2, pk, "string")
lassert(#pk == 32, "public key length must be 32", 2) lassert(#pk == 32, "public key length must be 32", 2) --- @cast pk String32
return c25.encode(c25.scale(c25.ladder8(c25.decodeEd(pk), util.bits8(sk)))) return c25.encode(c25.scale(c25.ladder8(c25.decodeEd(pk), util.bits8(sk))))
end end

View file

@ -8,6 +8,8 @@ local sha512 = require "ccryptolib.internal.sha512"
local random = require "ccryptolib.random" local random = require "ccryptolib.random"
--- Masks an exchange secret key. --- Masks an exchange secret key.
--- @param sk string A random 32-byte Curve25519 secret key.
--- @return string msk A masked secret key.
local function maskX(sk) local function maskX(sk)
expect(1, sk, "string") expect(1, sk, "string")
lassert(#sk == 32, "secret key length must be 32", 2) lassert(#sk == 32, "secret key length must be 32", 2)
@ -19,6 +21,8 @@ local function maskX(sk)
end end
--- Masks a signature secret key. --- Masks a signature secret key.
--- @param sk string A random 32-byte Edwards25519 secret key.
--- @return string msk A masked secret key.
function maskS(sk) function maskS(sk)
expect(1, sk, "string") expect(1, sk, "string")
lassert(#sk == 32, "secret key length must be 32", 2) lassert(#sk == 32, "secret key length must be 32", 2)
@ -26,27 +30,29 @@ function maskS(sk)
end end
--- Rerandomizes the masking on a masked key. --- Rerandomizes the masking on a masked key.
local function remask(sk) --- @param msk string A masked secret key.
expect(1, sk, "string") --- @return string msk The same secret key, but with another mask.
lassert(#sk == 64, "masked secret key length must be 64", 2) local function remask(msk)
expect(1, msk, "string")
lassert(#msk == 64, "masked secret key length must be 64", 2)
local newMask = random.random(32) local newMask = random.random(32)
local xr = fq.decode(sk:sub(1, 32)) local xr = fq.decode(msk:sub(1, 32))
local r = fq.decodeClamped(sk:sub(33)) local r = fq.decodeClamped(msk:sub(33))
local s = fq.decodeClamped(newMask) local s = fq.decodeClamped(newMask)
local xs = fq.add(xr, fq.sub(r, s)) local xs = fq.add(xr, fq.sub(r, s))
return fq.encode(xs) .. newMask return fq.encode(xs) .. newMask
end end
--- Returns the ephemeral exchange secret key of this masked key. --- Returns the ephemeral exchange secret key of this masked key.
-- --- This is the second secret key in the "double key exchange" in @{exchange},
-- This is the second secret key in the "double key exchange" in @{exchange}, --- the first being the key that has been masked. The ephemeral key changes
-- the first being the key that has been masked. The ephemeral key changes every --- every time @{remask} is called.
-- time @{remask} is called. --- @param msk string A masked secret key.
-- --- @return string esk The ephemeral half of the masked secret key.
local function ephemeralSk(sk) local function ephemeralSk(msk)
expect(1, sk, "string") expect(1, msk, "string")
lassert(#sk == 64, "masked secret key length must be 64", 2) lassert(#msk == 64, "masked secret key length must be 64", 2)
return sk:sub(33) return msk:sub(33)
end end
local function exchangeOnPoint(sk, P) local function exchangeOnPoint(sk, P)
@ -108,54 +114,69 @@ local function exchangeOnPoint(sk, P)
end end
--- Returns the X25519 public key of this masked key. --- Returns the X25519 public key of this masked key.
local function publicKeyX(sk) --- @param msk string A masked secret key.
expect(1, sk, "string") local function publicKeyX(msk)
lassert(#sk == 64, "masked secret key length must be 64", 2) expect(1, msk, "string")
return (exchangeOnPoint(sk, c25.G)) lassert(#msk == 64, "masked secret key length must be 64", 2)
return (exchangeOnPoint(msk, c25.G))
end end
--- Returns the Ed25519 public key of this masked key. --- Returns the Ed25519 public key of this masked key.
local function publicKeyS(sk) --- @param msk string A masked secret key.
expect(1, sk, "string") --- @return string pk The Ed25519 public key matching this masked key.
lassert(#sk == 64, "masked secret key length must be 64", 2) local function publicKeyS(msk)
local xr = fq.decode(sk:sub(1, 32)) expect(1, msk, "string")
local r = fq.decodeClamped(sk:sub(33)) lassert(#msk == 64, "masked secret key length must be 64", 2)
local xr = fq.decode(msk:sub(1, 32))
local r = fq.decodeClamped(msk:sub(33))
local y = ed.add(ed.mulG(fq.bits(xr)), ed.niels(ed.mulG(fq.bits(r)))) local y = ed.add(ed.mulG(fq.bits(xr)), ed.niels(ed.mulG(fq.bits(r))))
return ed.encode(ed.scale(y)) return ed.encode(ed.scale(y))
end end
--- Performs a double key exchange. --- Performs a double key exchange.
-- ---
-- Returns 0 if the input public key has small order or if it isn't in the base --- Returns 0 if the input public key has small order or if it isn't in the base
-- curve. This is different from standard X25519, which performs the exchange --- curve. This is different from standard X25519, which performs the exchange
-- even on the twist. --- even on the twist.
-- ---
-- May incorrectly return 0 with negligible chance if the mask happens to match --- May incorrectly return 0 with negligible chance if the mask happens to match
-- the masked key. I haven't checked if clamping prevents that from happening. --- the masked key. I haven't checked if clamping prevents that from happening.
-- ---
--- @param sk string A masked secret key.
--- @param pk string An X25519 public key.
--- @return string sss The shared secret between the public key and the static half of the masked key.
--- @return string sse The shared secret betwen the public key and the ephemeral half of the masked key.
local function exchangeX(sk, pk) local function exchangeX(sk, pk)
expect(1, sk, "string") expect(1, sk, "string")
lassert(#sk == 64, "masked secret key length must be 64", 2) lassert(#sk == 64, "masked secret key length must be 64", 2)
expect(2, pk, "string") expect(2, pk, "string")
lassert(#pk == 32, "public key length must be 32", 2) lassert(#pk == 32, "public key length must be 32", 2) --- @cast pk String32
return exchangeOnPoint(sk, c25.decode(pk)) return exchangeOnPoint(sk, c25.decode(pk))
end end
--- Performs an exchange against an Ed25519 key. --- Performs an exchange against an Ed25519 key.
-- ---
-- This is done by converting the key into X25519 before passing it to the --- This is done by converting the key into X25519 before passing it to the
-- regular exchange. Using this function on the result of @{signaturePk} leads --- regular exchange. Using this function on the result of @{signaturePk} leads
-- to the same value as using @{exchange} on the result of @{exchangePk}. --- to the same value as using @{exchange} on the result of @{exchangePk}.
-- ---
--- @param sk string A masked secret key.
--- @param pk string An Ed25519 public key.
--- @return string sss The shared secret between the public key and the static half of the masked key.
--- @return string sse The shared secret betwen the public key and the ephemeral half of the masked key.
local function exchangeS(sk, pk) local function exchangeS(sk, pk)
expect(1, sk, "string") expect(1, sk, "string")
lassert(#sk == 64, "masked secret key length must be 64", 2) lassert(#sk == 64, "masked secret key length must be 64", 2)
expect(2, pk, "string") expect(2, pk, "string")
lassert(#pk == 32, "public key length must be 32", 2) lassert(#pk == 32, "public key length must be 32", 2) --- @cast pk String32
return exchangeOnPoint(sk, c25.decodeEd(pk)) return exchangeOnPoint(sk, c25.decodeEd(pk))
end end
--- Signs a message using Ed25519. --- Signs a message using Ed25519.
--- @param sk string A masked secret key.
--- @param pk string The Ed25519 public key matching the secret key.
--- @param msg string A message to sign.
--- @return string sig The signature on the message.
local function sign(sk, pk, msg) local function sign(sk, pk, msg)
expect(1, sk, "string") expect(1, sk, "string")
lassert(#sk == 64, "masked secret key length must be 64", 2) lassert(#sk == 64, "masked secret key length must be 64", 2)

View file

@ -16,15 +16,15 @@ describe("aead.encrypt", function()
-- Types -- Types
expect.error(aead.encrypt, nil, nonce, msg, aad, rounds) expect.error(aead.encrypt, nil, nonce, msg, aad, rounds)
:eq("bad argument #1 (expected string, got nil)") :eq("bad argument #1 (string expected, got nil)")
expect.error(aead.encrypt, key, nil, msg, aad, rounds) expect.error(aead.encrypt, key, nil, msg, aad, rounds)
:eq("bad argument #2 (expected string, got nil)") :eq("bad argument #2 (string expected, got nil)")
expect.error(aead.encrypt, key, nonce, nil, aad, rounds) expect.error(aead.encrypt, key, nonce, nil, aad, rounds)
:eq("bad argument #3 (expected string, got nil)") :eq("bad argument #3 (string expected, got nil)")
expect.error(aead.encrypt, key, nonce, msg, nil, rounds) expect.error(aead.encrypt, key, nonce, msg, nil, rounds)
:eq("bad argument #4 (expected string, got nil)") :eq("bad argument #4 (string expected, got nil)")
expect.error(aead.encrypt, key, nonce, msg, aad, {}) expect.error(aead.encrypt, key, nonce, msg, aad, {})
:eq("bad argument #5 (expected number, got table)") :eq("bad argument #5 (number expected, got table)")
-- String lengths -- String lengths
expect.error(aead.encrypt, key .. "a", nonce, msg, aad, rounds) expect.error(aead.encrypt, key .. "a", nonce, msg, aad, rounds)
@ -155,17 +155,17 @@ describe("aead.decrypt", function()
-- Types -- Types
expect.error(aead.decrypt, nil, nonce, tag, ctx, aad, rounds) expect.error(aead.decrypt, nil, nonce, tag, ctx, aad, rounds)
:eq("bad argument #1 (expected string, got nil)") :eq("bad argument #1 (string expected, got nil)")
expect.error(aead.decrypt, key, nil, tag, ctx, aad, rounds) expect.error(aead.decrypt, key, nil, tag, ctx, aad, rounds)
:eq("bad argument #2 (expected string, got nil)") :eq("bad argument #2 (string expected, got nil)")
expect.error(aead.decrypt, key, nonce, nil, ctx, aad, rounds) expect.error(aead.decrypt, key, nonce, nil, ctx, aad, rounds)
:eq("bad argument #3 (expected string, got nil)") :eq("bad argument #3 (string expected, got nil)")
expect.error(aead.decrypt, key, nonce, tag, nil, aad, rounds) expect.error(aead.decrypt, key, nonce, tag, nil, aad, rounds)
:eq("bad argument #4 (expected string, got nil)") :eq("bad argument #4 (string expected, got nil)")
expect.error(aead.decrypt, key, nonce, tag, ctx, nil, rounds) expect.error(aead.decrypt, key, nonce, tag, ctx, nil, rounds)
:eq("bad argument #5 (expected string, got nil)") :eq("bad argument #5 (string expected, got nil)")
expect.error(aead.decrypt, key, nonce, tag, ctx, aad, {}) expect.error(aead.decrypt, key, nonce, tag, ctx, aad, {})
:eq("bad argument #6 (expected number, got table)") :eq("bad argument #6 (number expected, got table)")
-- String lengths -- String lengths
expect.error(aead.decrypt, key .. "a", nonce, tag, ctx, aad, rounds) expect.error(aead.decrypt, key .. "a", nonce, tag, ctx, aad, rounds)

View file

@ -17,9 +17,9 @@ describe("blake3.digest", function()
it("validates arguments", function() it("validates arguments", function()
-- Types -- Types
expect.error(blake3.digest, nil) expect.error(blake3.digest, nil)
:eq("bad argument #1 (expected string, got nil)") :eq("bad argument #1 (string expected, got nil)")
expect.error(blake3.digest, "", {}) expect.error(blake3.digest, "", {})
:eq("bad argument #2 (expected number, got table)") :eq("bad argument #2 (number expected, got table)")
-- Length -- Length
expect.error(blake3.digest, "", 0.5) expect.error(blake3.digest, "", 0.5)
@ -51,11 +51,11 @@ describe("blake3.digestKeyed", function()
-- Types -- Types
expect.error(blake3.digestKeyed, nil, "") expect.error(blake3.digestKeyed, nil, "")
:eq("bad argument #1 (expected string, got nil)") :eq("bad argument #1 (string expected, got nil)")
expect.error(blake3.digestKeyed, key, nil) expect.error(blake3.digestKeyed, key, nil)
:eq("bad argument #2 (expected string, got nil)") :eq("bad argument #2 (string expected, got nil)")
expect.error(blake3.digestKeyed, key, "", {}) expect.error(blake3.digestKeyed, key, "", {})
:eq("bad argument #3 (expected number, got table)") :eq("bad argument #3 (number expected, got table)")
-- String lengths -- String lengths
expect.error(blake3.digestKeyed, key .. "a", "") expect.error(blake3.digestKeyed, key .. "a", "")
@ -90,11 +90,11 @@ describe("blake3.deriveKey", function()
it("validates arguments", function() it("validates arguments", function()
-- Types -- Types
expect.error(blake3.deriveKey, nil) expect.error(blake3.deriveKey, nil)
:eq("bad argument #1 (expected string, got nil)") :eq("bad argument #1 (string expected, got nil)")
expect.error(blake3.deriveKey(""), nil) expect.error(blake3.deriveKey(""), nil)
:eq("bad argument #1 (expected string, got nil)") :eq("bad argument #1 (string expected, got nil)")
expect.error(blake3.deriveKey(""), "", {}) expect.error(blake3.deriveKey(""), "", {})
:eq("bad argument #2 (expected number, got table)") :eq("bad argument #2 (number expected, got table)")
-- Length -- Length
expect.error(blake3.deriveKey(""), "", 0.5) expect.error(blake3.deriveKey(""), "", 0.5)

View file

@ -16,15 +16,15 @@ describe("chacha20.crypt", function()
-- Types -- Types
expect.error(chacha20.crypt, nil, nonce, msg, rounds, offset) expect.error(chacha20.crypt, nil, nonce, msg, rounds, offset)
:eq("bad argument #1 (expected string, got nil)") :eq("bad argument #1 (string expected, got nil)")
expect.error(chacha20.crypt, key, nil, msg, rounds, offset) expect.error(chacha20.crypt, key, nil, msg, rounds, offset)
:eq("bad argument #2 (expected string, got nil)") :eq("bad argument #2 (string expected, got nil)")
expect.error(chacha20.crypt, key, nonce, nil, rounds, offset) expect.error(chacha20.crypt, key, nonce, nil, rounds, offset)
:eq("bad argument #3 (expected string, got nil)") :eq("bad argument #3 (string expected, got nil)")
expect.error(chacha20.crypt, key, nonce, msg, {}, offset) expect.error(chacha20.crypt, key, nonce, msg, {}, offset)
:eq("bad argument #4 (expected number, got table)") :eq("bad argument #4 (number expected, got table)")
expect.error(chacha20.crypt, key, nonce, msg, nil, {}) expect.error(chacha20.crypt, key, nonce, msg, nil, {})
:eq("bad argument #5 (expected number, got table)") :eq("bad argument #5 (number expected, got table)")
-- String lengths -- String lengths
expect.error(chacha20.crypt, key .. "a", nonce, msg, rounds, offset) expect.error(chacha20.crypt, key .. "a", nonce, msg, rounds, offset)

View file

@ -13,9 +13,9 @@ describe("poly1305.mac", function()
-- Types -- Types
expect.error(poly1305.mac, nil, msg) expect.error(poly1305.mac, nil, msg)
:eq("bad argument #1 (expected string, got nil)") :eq("bad argument #1 (string expected, got nil)")
expect.error(poly1305.mac, key, nil) expect.error(poly1305.mac, key, nil)
:eq("bad argument #2 (expected string, got nil)") :eq("bad argument #2 (string expected, got nil)")
-- Key length -- Key length
expect.error(poly1305.mac, key .. "a", msg) expect.error(poly1305.mac, key .. "a", msg)

View file

@ -12,7 +12,7 @@ local longMsg = require "spec.vec.sha256.long"
describe("sha256.digest", function() describe("sha256.digest", function()
it("validates arguments", function() it("validates arguments", function()
expect.error(sha256.digest, nil) expect.error(sha256.digest, nil)
:eq("bad argument #1 (expected string, got nil)") :eq("bad argument #1 (string expected, got nil)")
end) end)
it("passes the NIST SHAVS byte-oriented short messages test", function() it("passes the NIST SHAVS byte-oriented short messages test", function()

View file

@ -9,10 +9,10 @@ local sha512 = require "ccryptolib.internal.sha512"
local shortMsg = require "spec.vec.sha512.short" local shortMsg = require "spec.vec.sha512.short"
local longMsg = require "spec.vec.sha512.long" local longMsg = require "spec.vec.sha512.long"
describe("sha256.digest", function() describe("sha512.digest", function()
it("validates arguments", function() it("validates arguments", function()
expect.error(sha512.digest, nil) expect.error(sha512.digest, nil)
:eq("bad argument #1 (expected string, got nil)") :eq("bad argument #1 (string expected, got nil)")
end) end)
it("passes the NIST SHAVS byte-oriented short messages test", function() it("passes the NIST SHAVS byte-oriented short messages test", function()