Skip to content

Commit

Permalink
fix(random): review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Tieske committed Nov 15, 2023
1 parent 1dd4e45 commit 94dd3bb
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 39 deletions.
16 changes: 8 additions & 8 deletions spec/02-random_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,32 @@ describe("Random:", function()
describe("random()", function()

it("should return random bytes for a valid number of bytes", function()
local num_bytes = system.MAX_RANDOM_BUFFER_SIZE
local num_bytes = 1
local result, err_msg = system.random(num_bytes)
assert.is_nil(err_msg)
assert.is.string(result)
assert.is_equal(num_bytes, #result)
end)


it("should return an error message for an invalid number of bytes", function()
it("should return an empty string for 0 bytes", function()
local num_bytes = 0
local result, err_msg = system.random(num_bytes)
assert.is.falsy(result)
assert.are.equal("invalid number of bytes, must be between 1 and 1024", err_msg)
assert.is_nil(err_msg)
assert.are.equal("", result)
end)


it("should return an error message for exceeding the maximum buffer size", function()
local num_bytes = system.MAX_RANDOM_BUFFER_SIZE + 1
it("should return an error message for an invalid number of bytes", function()
local num_bytes = -1
local result, err_msg = system.random(num_bytes)
assert.is.falsy(result)
assert.are.equal("invalid number of bytes, must be between 1 and 1024", err_msg)
assert.are.equal("invalid number of bytes, must not be less than 0", err_msg)
end)


it("should not return duplicate results", function()
local num_bytes = 10
local num_bytes = 1025
local result1, err_msg = system.random(num_bytes)
assert.is_nil(err_msg)
assert.is.string(result1)
Expand Down
27 changes: 27 additions & 0 deletions src/compat.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,31 @@
void luaL_setfuncs(lua_State *L, const luaL_Reg *l, int nup);
#endif

// Windows doesn't have ssize_t, so we define it here
#ifdef _WIN32
#if SIZE_MAX == UINT_MAX
typedef int ssize_t; /* common 32 bit case */
#define SSIZE_MIN INT_MIN
#define SSIZE_MAX INT_MAX
#elif SIZE_MAX == ULONG_MAX
typedef long ssize_t; /* linux 64 bits */
#define SSIZE_MIN LONG_MIN
#define SSIZE_MAX LONG_MAX
#elif SIZE_MAX == ULLONG_MAX
typedef long long ssize_t; /* windows 64 bits */
#define SSIZE_MIN LLONG_MIN
#define SSIZE_MAX LLONG_MAX
#elif SIZE_MAX == USHRT_MAX
typedef short ssize_t; /* is this even possible? */
#define SSIZE_MIN SHRT_MIN
#define SSIZE_MAX SHRT_MAX
#elif SIZE_MAX == UINTMAX_MAX
typedef uintmax_t ssize_t; /* last resort, chux suggestion */
#define SSIZE_MIN INTMAX_MIN
#define SSIZE_MAX INTMAX_MAX
#else
#error platform has exotic SIZE_MAX
#endif
#endif

#endif
71 changes: 40 additions & 31 deletions src/random.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,44 @@
#include "windows.h"
#include "wincrypt.h"
#else
#include <errno.h>
#include <unistd.h>
#include <string.h>
#endif


// Maximum buffer size for random bytes
#define MAX_RANDOM_BUFFER_SIZE 1024


/***
Generate random bytes.
This uses `getrandom()` on Linux, `CryptGenRandom()` on Windows, and `/dev/urandom` on macOS.
This uses `CryptGenRandom()` on Windows, and `/dev/urandom` on other platforms. It will return the
requested number of bytes, or an error, never a partial result.
@function random
@tparam[opt=1] int length number of bytes to get, must be less than or equal to `MAX_RANDOM_BUFFER_SIZE` (1024)
@tparam[opt=1] int length number of bytes to get
@treturn[1] string string of random bytes
@treturn[2] nil
@treturn[2] string error message
*/
static int lua_get_random_bytes(lua_State* L) {
int num_bytes = luaL_optinteger(L, 1, 1); // Number of bytes, default to 1 if not provided

if (num_bytes <= 0 || num_bytes > MAX_RANDOM_BUFFER_SIZE) {
if (num_bytes <= 0) {
if (num_bytes == 0) {
lua_pushliteral(L, "");
return 1;
}
lua_pushnil(L);
lua_pushfstring(L, "invalid number of bytes, must be between 1 and %d", MAX_RANDOM_BUFFER_SIZE);
lua_pushstring(L, "invalid number of bytes, must not be less than 0");
return 2;
}

unsigned char buffer[MAX_RANDOM_BUFFER_SIZE];
size_t n;
unsigned char* buffer = (unsigned char*)lua_newuserdata(L, num_bytes);
if (buffer == NULL) {
lua_pushnil(L);
lua_pushstring(L, "failed to allocate memory for random buffer");
return 2;
}

ssize_t n;
ssize_t total_read = 0;

#ifdef _WIN32
HCRYPTPROV hCryptProv;
Expand All @@ -56,33 +66,34 @@ static int lua_get_random_bytes(lua_State* L) {

CryptReleaseContext(hCryptProv, 0);
#else
#ifndef __APPLE__
// Neither Apple nor Windows
n = getrandom(buffer, num_bytes, 0);

if (n < 0) {
lua_pushnil(L);
lua_pushstring(L, "failed to get random data");
return 2;
}
#else
// macOS uses /dev/urandom for non-blocking
int fd = open("/dev/urandom", O_RDONLY);
// for macOS/unixes use /dev/urandom for non-blocking
int fd = open("/dev/urandom", O_RDONLY | O_CLOEXEC);
if (fd < 0) {
lua_pushnil(L);
lua_pushstring(L, "failed opening /dev/urandom");
return 2;
}
n = read(fd, buffer, num_bytes);
close(fd);
if (n < 0) {
lua_pushnil(L);
lua_pushstring(L, "failed reading /dev/urandom");
return 2;
}

#endif
while (total_read < num_bytes) {
n = read(fd, buffer + total_read, num_bytes - total_read);

if (n < 0) {
if (errno == EINTR) {
continue; // Interrupted, retry

} else {
lua_pushnil(L);
lua_pushfstring(L, "failed reading /dev/urandom: %s", strerror(errno));
close(fd);
return 2;
}
}

total_read += n;
}

close(fd);
#endif

lua_pushlstring(L, (const char*)buffer, num_bytes);
Expand All @@ -103,6 +114,4 @@ static luaL_Reg func[] = {
*-------------------------------------------------------------------------*/
void random_open(lua_State *L) {
luaL_setfuncs(L, func, 0);
lua_pushinteger(L, MAX_RANDOM_BUFFER_SIZE);
lua_setfield(L, -2, "MAX_RANDOM_BUFFER_SIZE");
}

0 comments on commit 94dd3bb

Please sign in to comment.