summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/infer.lua51
-rw-r--r--test/type_inference/init.lua5
2 files changed, 51 insertions, 5 deletions
diff --git a/script/core/infer.lua b/script/core/infer.lua
index 716fbcf9..14ec6be2 100644
--- a/script/core/infer.lua
+++ b/script/core/infer.lua
@@ -1,5 +1,6 @@
local searcher = require 'core.searcher'
local config = require 'config'
+local linker = require 'core.linker'
local BE_LEN = {'#'}
local CLASS = {'CLASS'}
@@ -260,10 +261,43 @@ local function searchInfer(source, infers)
infers['any'] = true
return
end
- -- # XX -> string | table
- if source.parent.type == 'unary'
- and source.parent.op.type == '#' then
- infers[BE_LEN] = true
+ if source.parent.type == 'unary' then
+ local op = source.parent.op.type
+ -- # XX -> string | table
+ if op == '#' then
+ infers[BE_LEN] = true
+ return
+ end
+ if op == '-' then
+ infers['number'] = true
+ return
+ end
+ if op == '~' then
+ infers['integer'] = true
+ return
+ end
+ return
+ end
+ if source.parent.type == 'binary' then
+ local op = source.parent.op.type
+ if op == '+'
+ or op == '-'
+ or op == '*'
+ or op == '/'
+ or op == '//'
+ or op == '^'
+ or op == '%' then
+ infers['number'] = true
+ return
+ end
+ if op == '<<'
+ or op == '>>'
+ or op == '~'
+ or op == '|'
+ or op == '&' then
+ infers['integer'] = true
+ return
+ end
return
end
end
@@ -289,6 +323,15 @@ function m.searchInfers(source)
for _, def in ipairs(defs) do
searchInfer(def, infers)
end
+ local id = linker.getID(source)
+ if id then
+ local link = linker.getLinkByID(source, id)
+ if link and link.sources then
+ for _, src in ipairs(link.sources) do
+ searchInfer(src, infers)
+ end
+ end
+ end
cleanInfers(infers)
return infers
end
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 9a6cce6f..95933b8d 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -244,8 +244,11 @@ end
]]
TEST 'string' [[
+---@return string
+local function f2() end
+
local function f()
- return string.sub()
+ return f2()
end
local <?x?> = f()