summaryrefslogtreecommitdiff
path: root/src/random.c
blob: 90fb3f2b3b74e8c1aaf164f7ed046d67e053f102 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/// @submodule system
#include <lua.h>
#include <lauxlib.h>
#include "compat.h"
#include <fcntl.h>
#include <limits.h>
#include <stdio.h>

#ifdef _WIN32
#include "windows.h"
#include "wincrypt.h"
#else
#include <errno.h>
#include <unistd.h>
#include <string.h>
#endif


/***
Generate random bytes.
This uses `CryptGenRandom()` on Windows, and `/dev/urandom` on other platforms. It will return the
requested number of bytes, or an error, never a partial result.
@function random
@tparam[opt=1] int length number of bytes to get
@treturn[1] string string of random bytes
@treturn[2] nil
@treturn[2] string error message
*/
static int lua_get_random_bytes(lua_State* L) {
    // Keep the full lua_Integer: narrowing straight to int would silently wrap
    // oversized requests (e.g. 2^32 + 1 becoming 1) instead of rejecting them.
    lua_Integer requested = luaL_optinteger(L, 1, 1); // default to 1 byte if not provided

    if (requested <= 0) {
        if (requested == 0) {
            lua_pushliteral(L, "");
            return 1;
        }
        lua_pushnil(L);
        lua_pushstring(L, "invalid number of bytes, must not be less than 0");
        return 2;
    }
    // Capping at INT_MAX keeps the length representable both as a DWORD for
    // CryptGenRandom() and as a size_t for read() on every platform.
    if (requested > INT_MAX) {
        lua_pushnil(L);
        lua_pushstring(L, "invalid number of bytes, too large");
        return 2;
    }
    size_t num_bytes = (size_t)requested;

    // The scratch buffer is a Lua userdata, so it is owned by the Lua state and
    // needs no explicit free on any return path. lua_newuserdata raises a Lua
    // error on allocation failure; the NULL check is purely defensive.
    unsigned char* buffer = (unsigned char*)lua_newuserdata(L, num_bytes);
    if (buffer == NULL) {
        lua_pushnil(L);
        lua_pushstring(L, "failed to allocate memory for random buffer");
        return 2;
    }

#ifdef _WIN32
    // lua_pushfstring() has no "%lu" conversion (Lua 5.3+ raises an error on it,
    // Lua 5.1 emits it literally), so the DWORD error code is formatted here.
    char code[32];
    HCRYPTPROV hCryptProv;
    if (!CryptAcquireContext(&hCryptProv, NULL, NULL, PROV_RSA_FULL, CRYPT_VERIFYCONTEXT)) {
        DWORD error = GetLastError();
        snprintf(code, sizeof code, "%lu", (unsigned long)error);
        lua_pushnil(L);
        lua_pushfstring(L, "failed to acquire cryptographic context: %s", code);
        return 2;
    }

    if (!CryptGenRandom(hCryptProv, (DWORD)num_bytes, buffer)) {
        DWORD error = GetLastError(); // capture before any other API call resets it
        CryptReleaseContext(hCryptProv, 0);
        snprintf(code, sizeof code, "%lu", (unsigned long)error);
        lua_pushnil(L);
        lua_pushfstring(L, "failed to get random data: %s", code);
        return 2;
    }

    CryptReleaseContext(hCryptProv, 0);
#else
    // For macOS and other Unix-likes, use /dev/urandom because it does not block.
    int fd = open("/dev/urandom", O_RDONLY | O_CLOEXEC);
    if (fd < 0) {
        lua_pushnil(L);
        lua_pushstring(L, "failed opening /dev/urandom");
        return 2;
    }

    // read() may return fewer bytes than requested, so loop until the buffer is
    // full; a partial result is never returned to the caller.
    size_t total_read = 0;
    while (total_read < num_bytes) {
        ssize_t n = read(fd, buffer + total_read, num_bytes - total_read);

        if (n < 0) {
            if (errno == EINTR) {
                continue;  // Interrupted by a signal, retry
            }
            int err = errno; // save before close() can overwrite it
            close(fd);
            lua_pushnil(L);
            lua_pushfstring(L, "failed reading /dev/urandom: %s", strerror(err));
            return 2;
        }
        if (n == 0) {
            // EOF should never happen on /dev/urandom, but without this check
            // the loop would spin forever.
            close(fd);
            lua_pushnil(L);
            lua_pushstring(L, "failed reading /dev/urandom: unexpected end of file");
            return 2;
        }

        total_read += (size_t)n;
    }

    close(fd);
#endif

    lua_pushlstring(L, (const char*)buffer, num_bytes);
    return 1;
}



static luaL_Reg func[] = {
    { "random", lua_get_random_bytes },
    { NULL, NULL }
};



/*
 * Submodule initializer: registers every entry of `func` (currently just
 * `random`) via luaL_setfuncs() into the table the caller has on top of the
 * Lua stack.
 */
void random_open(lua_State *L)
{
    const int shared_upvalues = 0; /* the registered C functions carry no upvalues */
    luaL_setfuncs(L, func, shared_upvalues);
}