summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/infer.lua170
-rw-r--r--test/type_inference/init.lua12
2 files changed, 105 insertions, 77 deletions
diff --git a/script/core/infer.lua b/script/core/infer.lua
index c70662a0..fe097915 100644
--- a/script/core/infer.lua
+++ b/script/core/infer.lua
@@ -12,6 +12,89 @@ local function mergeTable(a, b)
end
end
+local function searchInferOfUnary(value, infers)
+ local op = value.op.type
+ if op == 'not' then
+ infers['boolean'] = true
+ return
+ end
+ if op == '#' then
+ infers['integer'] = true
+ return
+ end
+ if op == '-' then
+ if m.hasType(value[1], 'integer') then
+ infers['integer'] = true
+ else
+ infers['number'] = true
+ end
+ return
+ end
+ if op == '~' then
+ infers['integer'] = true
+ return
+ end
+end
+
+local function searchInferOfBinary(value, infers)
+ local op = value.op.type
+ if op == 'and' then
+ if m.isTrue(value[1]) then
+ mergeTable(infers, m.searchInfers(value[2]))
+ else
+ mergeTable(infers, m.searchInfers(value[1]))
+ end
+ return
+ end
+ if op == 'or' then
+ if m.isTrue(value[1]) then
+ mergeTable(infers, m.searchInfers(value[1]))
+ else
+ mergeTable(infers, m.searchInfers(value[2]))
+ end
+ return
+ end
+ if op == '=='
+ or op == '~='
+ or op == '<'
+ or op == '>'
+ or op == '<='
+ or op == '>=' then
+ infers['boolean'] = true
+ return
+ end
+ if op == '<<'
+ or op == '>>'
+ or op == '~'
+ or op == '&'
+ or op == '|' then
+ infers['integer'] = true
+ return
+ end
+ if op == '..' then
+ infers['string'] = true
+ return
+ end
+ if op == '^'
+ or op == '/' then
+ infers['number'] = true
+ return
+ end
+ if op == '+'
+ or op == '-'
+ or op == '*'
+ or op == '%'
+ or op == '//' then
+ if m.hasType(value[1], 'integer')
+ and m.hasType(value[2], 'integer') then
+ infers['integer'] = true
+ else
+ infers['number'] = true
+ end
+ return
+ end
+end
+
local function searchInferOfValue(value, infers)
if value.type == 'string' then
infers['string'] = true
@@ -38,86 +121,12 @@ local function searchInferOfValue(value, infers)
return
end
if value.type == 'unary' then
- local op = value.op.type
- if op == 'not' then
- infers['boolean'] = true
- return
- end
- if op == '#' then
- infers['integer'] = true
- return
- end
- if op == '-' then
- if m.hasType(value[1], 'integer') then
- infers['integer'] = true
- else
- infers['number'] = true
- end
- return
- end
- if op == '~' then
- infers['integer'] = true
- return
- end
+ searchInferOfUnary(value, infers)
return
end
if value.type == 'binary' then
- local op = value.op.type
- if op == 'and' then
- if m.isTrue(value[1]) then
- mergeTable(infers, m.searchInfers(value[2]))
- else
- mergeTable(infers, m.searchInfers(value[1]))
- end
- return
- end
- if op == 'or' then
- if m.isTrue(value[1]) then
- mergeTable(infers, m.searchInfers(value[1]))
- else
- mergeTable(infers, m.searchInfers(value[2]))
- end
- return
- end
- if op == '=='
- or op == '~='
- or op == '<'
- or op == '>'
- or op == '<='
- or op == '>=' then
- infers['boolean'] = true
- return
- end
- if op == '<<'
- or op == '>>'
- or op == '~'
- or op == '&'
- or op == '|' then
- infers['integer'] = true
- return
- end
- if op == '..' then
- infers['string'] = true
- return
- end
- if op == '^'
- or op == '/' then
- infers['number'] = true
- return
- end
- if op == '+'
- or op == '-'
- or op == '*'
- or op == '%'
- or op == '//' then
- if m.hasType(value[1], 'integer')
- and m.hasType(value[2], 'integer') then
- infers['integer'] = true
- else
- infers['number'] = true
- end
- return
- end
+ searchInferOfBinary(value, infers)
+ return
end
end
@@ -198,6 +207,13 @@ local function searchInfer(source, infers)
searchInferOfValue(value, infers)
return
end
+ if source.type == 'doc.class.name' then
+ local name = source[1]
+ if name then
+ infers[name] = true
+ end
+ return
+ end
-- X.a
if source.next and source.next.node == source then
if source.next.type == 'setfield'
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index df8f4fe8..0a041433 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -153,14 +153,26 @@ _VERSION = 'Lua 5.4'
]]
TEST 'function' [[
+---@class stringlib
+local string
+
+string.sub = function () end
+
return ('x').<?sub?>
]]
TEST 'function' [[
+---@class stringlib
+local string
+
+string.sub = function () end
+
<?x?> = ('x').sub
]]
TEST 'function' [[
+_VERSION = 'Lua 5.4'
+
<?x?> = _VERSION.sub
]]