summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/infer.lua35
-rw-r--r--script/core/linker.lua76
-rw-r--r--script/core/searcher.lua5
-rw-r--r--script/parser/guide.lua46
-rw-r--r--script/vm/vm.lua8
-rw-r--r--test/definition/luadoc.lua2
-rw-r--r--test/type_inference/init.lua12
7 files changed, 144 insertions, 40 deletions
diff --git a/script/core/infer.lua b/script/core/infer.lua
index fe097915..c3dc17c6 100644
--- a/script/core/infer.lua
+++ b/script/core/infer.lua
@@ -98,15 +98,15 @@ end
local function searchInferOfValue(value, infers)
if value.type == 'string' then
infers['string'] = true
- return
+ return true
end
if value.type == 'boolean' then
infers['boolean'] = true
- return
+ return true
end
if value.type == 'table' then
infers['table'] = true
- return
+ return true
end
if value.type == 'number' then
if math.type(value[1]) == 'integer' then
@@ -114,20 +114,25 @@ local function searchInferOfValue(value, infers)
else
infers['number'] = true
end
- return
+ return true
+ end
+ if value.type == 'nil' then
+ infers['nil'] = true
+ return true
end
if value.type == 'function' then
infers['function'] = true
- return
+ return true
end
if value.type == 'unary' then
searchInferOfUnary(value, infers)
- return
+ return true
end
if value.type == 'binary' then
searchInferOfBinary(value, infers)
- return
+ return true
end
+ return false
end
local function searchLiteralOfValue(value, literals)
@@ -180,6 +185,10 @@ end
---@param infers string[]
---@return string
function m.viewInfers(infers)
+ -- 如果有显性的 any ,则直接显示为 any
+ if infers['any'] then
+ return 'any'
+ end
local count = 0
for infer in pairs(infers) do
count = count + 1
@@ -188,7 +197,8 @@ function m.viewInfers(infers)
for i = count + 1, #infers do
infers[i] = nil
end
- if #infers == 0 then
+ -- 如果没有任何显性类型,则推测为 unkonwn ,显示为 any
+ if count == 0 then
return 'any'
end
table.sort(infers)
@@ -202,6 +212,9 @@ local function searchInfer(source, infers)
if bindClassOrType(source) then
return
end
+ if searchInferOfValue(source, infers) then
+ return
+ end
local value = searcher.getObjectValue(source)
if value then
searchInferOfValue(value, infers)
@@ -214,7 +227,7 @@ local function searchInfer(source, infers)
end
return
end
- -- X.a
+ -- X.a -> table
if source.next and source.next.node == source then
if source.next.type == 'setfield'
or source.next.type == 'setindex'
@@ -223,6 +236,10 @@ local function searchInfer(source, infers)
end
return
end
+ -- return XX
+ if source.parent.type == 'return' then
+ infers['any'] = true
+ end
end
local function searchLiteral(source, literals)
diff --git a/script/core/linker.lua b/script/core/linker.lua
index 622561db..d9f3630a 100644
--- a/script/core/linker.lua
+++ b/script/core/linker.lua
@@ -94,6 +94,13 @@ local function getKey(source)
return nil, nil
elseif source.type == 'function' then
return source.start, nil
+ elseif source.type == 'string' then
+ return '', nil
+ elseif source.type == 'integer'
+ or source.type == 'number'
+ or source.type == 'boolean'
+ or source.type == 'nil' then
+ return source.start, nil
elseif source.type == '...' then
return source.start, nil
elseif source.type == 'select' then
@@ -162,6 +169,15 @@ local function checkMode(source)
if source.type == 'function' then
return 'f:'
end
+ if source.type == 'string' then
+ return 'str:'
+ end
+ if source.type == 'number'
+ or source.type == 'integer'
+ or source.type == 'boolean'
+ or source.type == 'nil' then
+ return 'l:'
+ end
if source.type == 'call' then
return 'c:'
end
@@ -232,6 +248,10 @@ local function getID(source)
local current = source
local index = 0
while true do
+ if current.type == 'paren' then
+ current = current.exp
+ goto CONTINUE
+ end
local id, node = getKey(current)
if not id then
break
@@ -245,6 +265,7 @@ local function getID(source)
if current.special == '_G' then
break
end
+ ::CONTINUE::
end
if index == 0 then
source._id = false
@@ -431,23 +452,15 @@ function m.compileLink(source)
end
getLink(id).call = source
-- 将 call 映射到 node#1 上
- do
- local select1ID = ('%s%s%s%s'):format(
- nodeID,
- SPLIT_CHAR,
- RETURN_INDEX_CHAR,
- 1
- )
- pushForward(id, select1ID)
- end
+ local callID = ('%s%s%s%s'):format(
+ nodeID,
+ SPLIT_CHAR,
+ RETURN_INDEX_CHAR,
+ 1
+ )
+ pushForward(id, callID)
-- 将setmetatable映射到 param1 以及 param2.__index 上
if node.special == 'setmetatable' then
- local callID = ('%s%s%s%s'):format(
- nodeID,
- SPLIT_CHAR,
- RETURN_INDEX_CHAR,
- 1
- )
local tblID = getID(source.args and source.args[1])
local metaID = getID(source.args and source.args[2])
local indexID
@@ -468,20 +481,41 @@ function m.compileLink(source)
end
if source.type == 'select' then
if source.vararg.type == 'call' then
- local nodeID = getID(source.vararg.node)
+ local call = source.vararg
+ local node = call.node
+ local nodeID = getID(node)
if not nodeID then
return
end
-- 将call的返回值接收映射到函数返回值上
- local callID = ('%s%s%s%s'):format(
+ local callXID = ('%s%s%s%s'):format(
nodeID,
SPLIT_CHAR,
RETURN_INDEX_CHAR,
source.index
)
- pushForward(id, callID)
- pushBackward(callID, id)
- getLink(id).call = source.vararg
+ pushForward(id, callXID)
+ pushBackward(callXID, id)
+ getLink(id).call = call
+ if node.special == 'pcall'
+ or node.special == 'xpcall' then
+ local index = source.index - 1
+ if index <= 0 then
+ return
+ end
+ local funcID = call.args and getID(call.args[1])
+ if not funcID then
+ return
+ end
+ local funcXID = ('%s%s%s%s'):format(
+ funcID,
+ SPLIT_CHAR,
+ RETURN_INDEX_CHAR,
+ index
+ )
+ pushForward(id, funcXID)
+ pushBackward(funcXID, id)
+ end
end
end
if source.type == 'doc.type.function' then
@@ -700,6 +734,8 @@ function m.compileLinks(source)
m.pushSource(src)
m.compileLink(src)
end)
+ -- Special rule: ('').XX -> stringlib.XX
+ pushForward('str:', 'dn:stringlib')
return Linkers
end
diff --git a/script/core/searcher.lua b/script/core/searcher.lua
index fc1c08d1..c869f456 100644
--- a/script/core/searcher.lua
+++ b/script/core/searcher.lua
@@ -95,6 +95,11 @@ function m.pushResult(status, mode, source)
results[#results+1] = source
end
end
+ if parent.type == 'return' then
+ if linker.getID(source) ~= status.id then
+ results[#results+1] = source
+ end
+ end
elseif mode == 'field' then
end
end
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index 048a15bb..a838a42e 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -687,4 +687,50 @@ function m.lineData(lines, row)
return lines[row]
end
+function m.isSet(source)
+ local tp = source.type
+ if tp == 'setglobal'
+ or tp == 'local'
+ or tp == 'setlocal'
+ or tp == 'setfield'
+ or tp == 'setmethod'
+ or tp == 'setindex'
+ or tp == 'tablefield'
+ or tp == 'tableindex' then
+ return true
+ end
+ if tp == 'call' then
+ local special = m.getSpecial(source.node)
+ if special == 'rawset' then
+ return true
+ end
+ end
+ return false
+end
+
+function m.isGet(source)
+ local tp = source.type
+ if tp == 'getglobal'
+ or tp == 'getlocal'
+ or tp == 'getfield'
+ or tp == 'getmethod'
+ or tp == 'getindex' then
+ return true
+ end
+ if tp == 'call' then
+ local special = m.getSpecial(source.node)
+ if special == 'rawget' then
+ return true
+ end
+ end
+ return false
+end
+
+function m.getSpecial(source)
+ if not source then
+ return nil
+ end
+ return source.special
+end
+
return m
diff --git a/script/vm/vm.lua b/script/vm/vm.lua
index b7eb1cde..ebd0102b 100644
--- a/script/vm/vm.lua
+++ b/script/vm/vm.lua
@@ -4,15 +4,11 @@ local files = require 'files'
local timer = require 'timer'
local setmetatable = setmetatable
-local assert = assert
-local require = require
-local type = type
local running = coroutine.running
local ipairs = ipairs
local log = log
local xpcall = xpcall
local mathHuge = math.huge
-local collectgarbage = collectgarbage
_ENV = nil
@@ -63,10 +59,6 @@ function m.getArgInfo(source)
return nil
end
-function m.getSpecial(source)
- return guide.getSpecial(source)
-end
-
function m.getKeyName(source)
if not source then
return nil
diff --git a/test/definition/luadoc.lua b/test/definition/luadoc.lua
index 14f24552..d0d95847 100644
--- a/test/definition/luadoc.lua
+++ b/test/definition/luadoc.lua
@@ -113,7 +113,7 @@ print(<?f?>)
TEST [[
local function f()
- return 1
+ return <!1!>
end
---@class Class
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 0a041433..9a6cce6f 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -171,6 +171,11 @@ string.sub = function () end
]]
TEST 'function' [[
+---@class stringlib
+local string
+
+string.sub = function () end
+
_VERSION = 'Lua 5.4'
<?x?> = _VERSION.sub
@@ -217,8 +222,11 @@ end
_, <?y?> = pcall(x)
]]
-TEST 'oslib' [[
-local <?os?> = require 'os'
+TEST 'integer' [[
+local function x()
+ return 1
+end
+_, <?y?> = xpcall(x)
]]
TEST 'string|table' [[