diff options
-rw-r--r-- | script/core/infer.lua | 35 | ||||
-rw-r--r-- | script/core/linker.lua | 76 | ||||
-rw-r--r-- | script/core/searcher.lua | 5 | ||||
-rw-r--r-- | script/parser/guide.lua | 46 | ||||
-rw-r--r-- | script/vm/vm.lua | 8 | ||||
-rw-r--r-- | test/definition/luadoc.lua | 2 | ||||
-rw-r--r-- | test/type_inference/init.lua | 12 |
7 files changed, 144 insertions, 40 deletions
diff --git a/script/core/infer.lua b/script/core/infer.lua index fe097915..c3dc17c6 100644 --- a/script/core/infer.lua +++ b/script/core/infer.lua @@ -98,15 +98,15 @@ end local function searchInferOfValue(value, infers) if value.type == 'string' then infers['string'] = true - return + return true end if value.type == 'boolean' then infers['boolean'] = true - return + return true end if value.type == 'table' then infers['table'] = true - return + return true end if value.type == 'number' then if math.type(value[1]) == 'integer' then @@ -114,20 +114,25 @@ local function searchInferOfValue(value, infers) else infers['number'] = true end - return + return true + end + if value.type == 'nil' then + infers['nil'] = true + return true end if value.type == 'function' then infers['function'] = true - return + return true end if value.type == 'unary' then searchInferOfUnary(value, infers) - return + return true end if value.type == 'binary' then searchInferOfBinary(value, infers) - return + return true end + return false end local function searchLiteralOfValue(value, literals) @@ -180,6 +185,10 @@ end ---@param infers string[] ---@return string function m.viewInfers(infers) + -- 如果有显性的 any ,则直接显示为 any + if infers['any'] then + return 'any' + end local count = 0 for infer in pairs(infers) do count = count + 1 @@ -188,7 +197,8 @@ function m.viewInfers(infers) for i = count + 1, #infers do infers[i] = nil end - if #infers == 0 then + -- 如果没有任何显性类型,则推测为 unkonwn ,显示为 any + if count == 0 then return 'any' end table.sort(infers) @@ -202,6 +212,9 @@ local function searchInfer(source, infers) if bindClassOrType(source) then return end + if searchInferOfValue(source, infers) then + return + end local value = searcher.getObjectValue(source) if value then searchInferOfValue(value, infers) @@ -214,7 +227,7 @@ local function searchInfer(source, infers) end return end - -- X.a + -- X.a -> table if source.next and source.next.node == source then if source.next.type == 'setfield' or source.next.type == 'setindex' @@ -223,6 +236,10 @@ local function searchInfer(source, infers) end return end + -- return XX + if source.parent.type == 'return' then + infers['any'] = true + end end local function searchLiteral(source, literals) diff --git a/script/core/linker.lua b/script/core/linker.lua index 622561db..d9f3630a 100644 --- a/script/core/linker.lua +++ b/script/core/linker.lua @@ -94,6 +94,13 @@ local function getKey(source) return nil, nil elseif source.type == 'function' then return source.start, nil + elseif source.type == 'string' then + return '', nil + elseif source.type == 'integer' + or source.type == 'number' + or source.type == 'boolean' + or source.type == 'nil' then + return source.start, nil elseif source.type == '...' then return source.start, nil elseif source.type == 'select' then @@ -162,6 +169,15 @@ local function checkMode(source) if source.type == 'function' then return 'f:' end + if source.type == 'string' then + return 'str:' + end + if source.type == 'number' + or source.type == 'integer' + or source.type == 'boolean' + or source.type == 'nil' then + return 'l:' + end if source.type == 'call' then return 'c:' end @@ -232,6 +248,10 @@ local function getID(source) local current = source local index = 0 while true do + if current.type == 'paren' then + current = current.exp + goto CONTINUE + end local id, node = getKey(current) if not id then break @@ -245,6 +265,7 @@ local function getID(source) if current.special == '_G' then break end + ::CONTINUE:: end if index == 0 then source._id = false @@ -431,23 +452,15 @@ function m.compileLink(source) end getLink(id).call = source -- 将 call 映射到 node#1 上 - do - local select1ID = ('%s%s%s%s'):format( - nodeID, - SPLIT_CHAR, - RETURN_INDEX_CHAR, - 1 - ) - pushForward(id, select1ID) - end + local callID = ('%s%s%s%s'):format( + nodeID, + SPLIT_CHAR, + RETURN_INDEX_CHAR, + 1 + ) + pushForward(id, callID) -- 将setmetatable映射到 param1 以及 param2.__index 上 if node.special == 'setmetatable' then - local callID = ('%s%s%s%s'):format( - nodeID, - SPLIT_CHAR, - RETURN_INDEX_CHAR, - 1 - ) local tblID = getID(source.args and source.args[1]) local metaID = getID(source.args and source.args[2]) local indexID @@ -468,20 +481,41 @@ function m.compileLink(source) end if source.type == 'select' then if source.vararg.type == 'call' then - local nodeID = getID(source.vararg.node) + local call = source.vararg + local node = call.node + local nodeID = getID(node) if not nodeID then return end -- 将call的返回值接收映射到函数返回值上 - local callID = ('%s%s%s%s'):format( + local callXID = ('%s%s%s%s'):format( nodeID, SPLIT_CHAR, RETURN_INDEX_CHAR, source.index ) - pushForward(id, callID) - pushBackward(callID, id) - getLink(id).call = source.vararg + pushForward(id, callXID) + pushBackward(callXID, id) + getLink(id).call = call + if node.special == 'pcall' + or node.special == 'xpcall' then + local index = source.index - 1 + if index <= 0 then + return + end + local funcID = call.args and getID(call.args[1]) + if not funcID then + return + end + local funcXID = ('%s%s%s%s'):format( + funcID, + SPLIT_CHAR, + RETURN_INDEX_CHAR, + index + ) + pushForward(id, funcXID) + pushBackward(funcXID, id) + end end end if source.type == 'doc.type.function' then @@ -700,6 +734,8 @@ function m.compileLinks(source) m.pushSource(src) m.compileLink(src) end) + -- Special rule: ('').XX -> stringlib.XX + pushForward('str:', 'dn:stringlib') return Linkers end diff --git a/script/core/searcher.lua b/script/core/searcher.lua index fc1c08d1..c869f456 100644 --- a/script/core/searcher.lua +++ b/script/core/searcher.lua @@ -95,6 +95,11 @@ function m.pushResult(status, mode, source) results[#results+1] = source end end + if parent.type == 'return' then + if linker.getID(source) ~= status.id then + results[#results+1] = source + end + end elseif mode == 'field' then end end diff --git a/script/parser/guide.lua b/script/parser/guide.lua index 048a15bb..a838a42e 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -687,4 +687,50 @@ function m.lineData(lines, row) return lines[row] end +function m.isSet(source) + local tp = source.type + if tp == 'setglobal' + or tp == 'local' + or tp == 'setlocal' + or tp == 'setfield' + or tp == 'setmethod' + or tp == 'setindex' + or tp == 'tablefield' + or tp == 'tableindex' then + return true + end + if tp == 'call' then + local special = m.getSpecial(source.node) + if special == 'rawset' then + return true + end + end + return false +end + +function m.isGet(source) + local tp = source.type + if tp == 'getglobal' + or tp == 'getlocal' + or tp == 'getfield' + or tp == 'getmethod' + or tp == 'getindex' then + return true + end + if tp == 'call' then + local special = m.getSpecial(source.node) + if special == 'rawget' then + return true + end + end + return false +end + +function m.getSpecial(source) + if not source then + return nil + end + return source.special +end + return m diff --git a/script/vm/vm.lua b/script/vm/vm.lua index b7eb1cde..ebd0102b 100644 --- a/script/vm/vm.lua +++ b/script/vm/vm.lua @@ -4,15 +4,11 @@ local files = require 'files' local timer = require 'timer' local setmetatable = setmetatable -local assert = assert -local require = require -local type = type local running = coroutine.running local ipairs = ipairs local log = log local xpcall = xpcall local mathHuge = math.huge -local collectgarbage = collectgarbage _ENV = nil @@ -63,10 +59,6 @@ function m.getArgInfo(source) return nil end -function m.getSpecial(source) - return guide.getSpecial(source) -end - function m.getKeyName(source) if not source then return nil diff --git a/test/definition/luadoc.lua b/test/definition/luadoc.lua index 14f24552..d0d95847 100644 --- a/test/definition/luadoc.lua +++ b/test/definition/luadoc.lua @@ -113,7 +113,7 @@ print(<?f?>) TEST [[ local function f() - return 1 + return <!1!> end ---@class Class diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 0a041433..9a6cce6f 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -171,6 +171,11 @@ string.sub = function () end ]] TEST 'function' [[ +---@class stringlib +local string + +string.sub = function () end + _VERSION = 'Lua 5.4' <?x?> = _VERSION.sub @@ -217,8 +222,11 @@ end _, <?y?> = pcall(x) ]] -TEST 'oslib' [[ -local <?os?> = require 'os' +TEST 'integer' [[ +local function x() + return 1 +end +_, <?y?> = xpcall(x) ]] TEST 'string|table' [[ |