summaryrefslogtreecommitdiff
path: root/script/plugins/ffi/c-parser/typed.lua
blob: c84b87e3c563efe42225d703ac5d165c8ec14b2a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
--------------------------------------------------------------------------------
-- Lua programming with types
--------------------------------------------------------------------------------

-- Use the optional "inspect" library to pretty-print offending values in
-- type errors, falling back to plain tostring when it is not installed.
-- pcall returns false plus the error *message* on failure, and that message
-- is a truthy string, so the status flag must be checked explicitly rather
-- than relying on `inspect or tostring`.
local has_inspect, inspect = pcall(require, "inspect")
if not has_inspect then
   inspect = tostring
end

local typed = {}

-- When true, sequence checks are short-circuited: is_sequence accepts any
-- table and a "{T}" type accepts any table without checking its elements.
local FAST = false

--- Report whether `xs` is a table whose keys are exactly the integers 1..#xs.
-- Tables with holes (e.g. {1, nil, 3}) are rejected: their `#` value may be
-- any border, so checking only that keys fall inside 1..#xs would classify
-- them non-deterministically. The empty table counts as a sequence.
-- When FAST is set, every table is accepted without inspection.
-- @param xs any value
-- @treturn boolean
local function is_sequence(xs)
   if type(xs) ~= "table" then
      return false
   end
   if FAST then
      return true
   end
   local len = #xs
   local count = 0
   for k in pairs(xs) do
      if type(k) ~= "number" or k < 1 or k > len or math.floor(k) ~= k then
         return false
      end
      count = count + 1
   end
   -- Every key is an integer in 1..len; the table is hole-free only if all
   -- len of them are present.
   return count == len
end

--- Name the type of a value.
-- Precedence: a metatable `__name` tag, then "array" for values that
-- is_sequence accepts, then the primitive Lua type from type().
-- @param t any value
-- @treturn string type name
local function type_of(t)
   local meta = getmetatable(t)
   if meta and meta.__name then
      return meta.__name
   end
   if is_sequence(t) then
      return "array"
   end
   return type(t)
end

--- Tag `t` with the type name `typ` by storing it as `__name` in its metatable.
-- An existing metatable is reused and modified in place, which also affects
-- any other table sharing it; otherwise a fresh metatable is attached.
-- @tparam table t table to tag
-- @tparam string typ type name
-- @treturn table `t` itself
local function set_type(t, typ)
   local meta = getmetatable(t) or {}
   meta.__name = typ
   return setmetatable(t, meta)
end

--- Tag table `t` as type `typ` and return it.
-- Takes the type name first, like typed_function(types, fn).
-- @tparam string typ type name
-- @tparam table t table to tag
-- @treturn table `t` itself
local function typed_table(typ, t)
   local tagged = set_type(t, typ)
   return tagged
end

--- Test whether `val` matches the type string `expected`, without raising.
-- Supported forms:
--   "T?"   nil, or a value matching T
--   "{T}"  a table whose ipairs-reachable elements all match T
--   "table" any table at all
--   other  an exact match against type_of(val)
-- @param val value to check
-- @tparam string expected type string
-- @return true on success; nil plus the actual type name on failure
local function try_check(val, expected)
   -- A trailing "?" makes the type optional: nil passes immediately,
   -- anything else is checked against the base type.
   local base = expected:match("^(.*)%?$")
   if base ~= nil then
      if val == nil then
         return true
      end
      expected = base
   end

   -- "{T}" describes a sequence whose elements all satisfy T. If any
   -- element fails, fall through to the generic checks below.
   local elem_type = expected:match("^{(.+)}$")
   if elem_type and type(val) == "table" then
      if FAST then
         return true
      end
      local mismatch = false
      for _, item in ipairs(val) do
         if not try_check(item, elem_type) then
            mismatch = true
            break
         end
      end
      if not mismatch then
         return true
      end
   end

   -- A bare "table" accepts any table, whatever its tagged type.
   if type(val) == "table" and expected == "table" then
      return true
   end

   local actual = type_of(val)
   if actual ~= expected then
      return nil, actual
   end
   return true
end

--- Assert that `val` satisfies the type string `expected`.
-- Returns true on success; otherwise raises a "type error" naming the
-- expected and actual types and showing `val` via inspect.
-- When both `category` (e.g. "argument", "return", "value") and position `n`
-- are given, they are included in the message. The error is then attributed
-- to the caller of the function that invoked typed_check (level 3), except
-- for "value", which uses level 2. Without category/n the error is
-- attributed to typed_check's direct caller (level 2).
-- @param val value to check
-- @tparam string expected type string
-- @tparam[opt] string category what is being checked
-- @tparam[opt] number n position of the checked value
-- @treturn boolean true when the check passes
local function typed_check(val, expected, category, n)
   local ok, actual = try_check(val, expected)
   if ok then
      return true
   end
   local shown = inspect(val)
   if not (category and n) then
      error(("type error: expected %s, got %s (%s)"):format(expected, actual, shown), 2)
   end
   local level = 3
   if category == "value" then
      level = 2
   end
   error(("type error: %s %d: expected %s, got %s (%s)"):format(category, n, expected, actual, shown), level)
end

--- Split `s` on every match of the Lua pattern `sep`.
-- Empty fields are kept, so "a," yields {"a", ""} and "" yields {""}.
-- @tparam string s input string
-- @tparam string sep Lua pattern separating fields
-- @treturn {string} the fields, in order
local function split(s, sep)
   local pieces = {}
   local start = 1
   while true do
      local first, last = s:find(sep, start)
      if not first then
         break
      end
      pieces[#pieces + 1] = s:sub(start, first - 1)
      start = last + 1
   end
   pieces[#pieces + 1] = s:sub(start, #s)
   return pieces
end

--- Wrap `fn` so that its arguments and return values are checked per call.
-- `types` is a signature string "in1, in2 -> out1, out2". Each entry is a
-- type string understood by typed_check (e.g. "string", "number?", "{T}").
-- "()" on either side declares zero values. When the first output is "*",
-- return types are not checked, but the number of returns still is.
-- @tparam string types function signature
-- @tparam function fn function to wrap
-- @treturn function wrapper that validates counts and types around `fn`
-- @raise when `types` has no "inputs -> outputs" form
local function typed_function(types, fn)
   local inp, outp = types:match("(.*[^%s])%s*%->%s*([^%s].*)")
   if not inp then
      error("invalid function type signature (expected \"inputs -> outputs\"): " .. types, 2)
   end
   -- "()" means "no inputs", mirroring how "()" is treated on the output side;
   -- without this a "() -> T" signature would demand exactly one argument.
   local ins = (inp == "()") and {} or split(inp, ",%s*")
   local outs = split(outp, ",%s*")
   return function(...)
      local args = table.pack(...)
      if args.n ~= #ins then
         error("wrong number of inputs (given " .. args.n .. " - expects " .. types .. ")", 2)
      end
      for i = 1, #ins do
         typed_check(args[i], ins[i], "argument", i)
      end
      local rets = table.pack(fn(...))
      if outp == "()" then
         if rets.n ~= 0 then
            error("wrong number of outputs (given " .. rets.n .. " - expects " .. types .. ")", 2)
         end
      else
         if rets.n ~= #outs then
            error("wrong number of outputs (given " .. rets.n .. " - expects " .. types .. ")", 2)
         end
         if outs[1] ~= "*" then
            for i = 1, #outs do
               typed_check(rets[i], outs[i], "return", i)
            end
         end
      end
      return table.unpack(rets, 1, rets.n)
   end
end

-- Metatable used while checking is on: calling the module table,
-- `typed(types, fn)`, builds a checked wrapper via typed_function.
local typed_mt_on = {}
function typed_mt_on.__call(_, types, fn)
   return typed_function(types, fn)
end

-- Metatable used while checking is off: `typed(types, fn)` ignores the
-- signature and hands back `fn` unchanged.
local typed_mt_off = {}
function typed_mt_off.__call(_, _, fn)
   return fn
end

--- Enable runtime type checking.
-- Installs the real implementations on the module table and makes
-- `typed(types, fn)` wrap `fn` with typed_function.
function typed.on()
   setmetatable(typed, typed_mt_on)
   typed.table = typed_table
   typed.set_type = set_type
   typed.typed = typed_function
   typed.check = typed_check
end

--- Disable runtime type checking.
-- Every entry point becomes a pass-through stub with the same parameters as
-- its enabled counterpart. `typed.check` returns true, as the enabled checker
-- does on success, so code such as `assert(typed.check(v, "T"))` works in
-- both modes.
function typed.off()
   typed.check = function() return true end
   typed.typed = function(_, fn) return fn end
   typed.set_type = function(t, _) return t end
   typed.table = function(_, t) return t end
   setmetatable(typed, typed_mt_off)
end

-- Checking starts disabled; call typed.on() to enable it.
typed.off()

return typed