summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsumneko <sumneko@hotmail.com>2022-03-12 04:48:33 +0800
committersumneko <sumneko@hotmail.com>2022-03-12 04:48:33 +0800
commitc461d51bf40563c3b74c652d8789ae302ea7c9e1 (patch)
treeefae279e57602aa19bc993dfcd10bd5990959313
parenta4baf8a43b0b25414d8fe08cb898e150d420e705 (diff)
downloadlua-language-server-c461d51bf40563c3b74c652d8789ae302ea7c9e1.zip
update
-rw-r--r--script/parser/guide.lua21
-rw-r--r--script/vm/compiler.lua59
-rw-r--r--test/type_inference/init.lua27
3 files changed, 95 insertions, 12 deletions
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index 73857149..f490b306 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -1182,4 +1182,25 @@ function m.isOOP(source)
return false
end
+local baseTypeMap = {
+ ['unknown'] = true,
+ ['any'] = true,
+ ['true'] = true,
+ ['false'] = true,
+ ['nil'] = true,
+ ['boolean'] = true,
+ ['number'] = true,
+ ['string'] = true,
+ ['table'] = true,
+ ['function'] = true,
+ ['thread'] = true,
+ ['userdata'] = true,
+}
+
+---@param str string
+---@return boolean
+function m.isBaseType(str)
+ return baseTypeMap[str] == true
+end
+
return m
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index c46519fd..9da8a55b 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -304,7 +304,7 @@ local function bindDocs(source)
end
end
if doc.type == 'doc.class' then
- if source.type == 'local'
+ if (source.type == 'local' and not isParam)
or (source._globalNode and guide.isSet(source)) then
hasFounded = true
nodeMgr.setNode(source, m.compileNode(doc))
@@ -407,7 +407,21 @@ local function selectNode(source, list, index)
return nodeMgr.setNode(source, result)
end
-local function setCallBackNode(source, call, callNode, fixIndex)
+---@param source parser.object
+---@param node vm.node
+---@return boolean
+local function isValidCallArgNode(source, node)
+ if source.type == 'function' then
+ return node.type == 'doc.type.function'
+ end
+ if source.type == 'table' then
+ return node.type == 'doc.type.table'
+ or (node.type == 'global' and node.cate == 'type' and not guide.isBaseType(node.name))
+ end
+ return false
+end
+
+local function setCallArgNode(source, call, callNode, fixIndex)
local valueMgr = require 'vm.value'
local myIndex
for i, arg in ipairs(call.args) do
@@ -416,6 +430,7 @@ local function setCallBackNode(source, call, callNode, fixIndex)
break
end
end
+
local eventIndex = 1
local eventArg = call.args[eventIndex + fixIndex]
if eventArg and eventArg.dummy then
@@ -423,11 +438,12 @@ local function setCallBackNode(source, call, callNode, fixIndex)
eventArg = call.args[eventIndex + fixIndex]
end
local eventMap = valueMgr.getLiterals(eventArg)
+
for n in nodeMgr.eachNode(callNode) do
if n.type == 'function' then
local arg = n.args[myIndex]
for fn in nodeMgr.eachNode(m.compileNode(arg)) do
- if fn.type == 'doc.type.function' then
+ if isValidCallArgNode(source, fn) then
nodeMgr.setNode(source, fn)
end
end
@@ -440,7 +456,7 @@ local function setCallBackNode(source, call, callNode, fixIndex)
or eventMap[event[1]] then
local arg = n.args[myIndex]
for fn in nodeMgr.eachNode(m.compileNode(arg)) do
- if fn.type == 'doc.type.function' then
+ if isValidCallArgNode(source, fn) then
nodeMgr.setNode(source, fn)
end
end
@@ -452,7 +468,6 @@ end
local compilerMap = util.switch()
: case 'nil'
: case 'boolean'
- : case 'table'
: case 'integer'
: case 'number'
: case 'string'
@@ -463,6 +478,32 @@ local compilerMap = util.switch()
: call(function (source)
nodeMgr.setNode(source, source)
end)
+ : case 'table'
+ : call(function (source)
+ nodeMgr.setNode(source, source)
+
+ if source.parent.type == 'callargs' then
+ local call = source.parent.parent
+ local callNode = m.compileNode(call.node)
+ setCallArgNode(source, call, callNode, 0)
+
+ if call.node.special == 'pcall'
+ or call.node.special == 'xpcall' then
+ local fixIndex = call.node.special == 'pcall' and 1 or 2
+ callNode = m.compileNode(call.args[1])
+ setCallArgNode(source, call, callNode, fixIndex)
+ end
+ end
+
+ if source.parent.type == 'setglobal'
+ or source.parent.type == 'setlocal'
+ or source.parent.type == 'tablefield'
+ or source.parent.type == 'tableindex'
+ or source.parent.type == 'setfield'
+ or source.parent.type == 'setindex' then
+ nodeMgr.setNode(source, m.compileNode(source.parent))
+ end
+ end)
: case 'function'
: call(function (source)
nodeMgr.setNode(source, source)
@@ -479,13 +520,13 @@ local compilerMap = util.switch()
if source.parent.type == 'callargs' then
local call = source.parent.parent
local callNode = m.compileNode(call.node)
- setCallBackNode(source, call, callNode, 0)
+ setCallArgNode(source, call, callNode, 0)
if call.node.special == 'pcall'
or call.node.special == 'xpcall' then
local fixIndex = call.node.special == 'pcall' and 1 or 2
callNode = m.compileNode(call.args[1])
- setCallBackNode(source, call, callNode, fixIndex)
+ setCallArgNode(source, call, callNode, fixIndex)
end
end
end)
@@ -578,6 +619,10 @@ local compilerMap = util.switch()
if source.value then
nodeMgr.setNode(source, m.compileNode(source.value))
end
+
+ m.compileByParentNode(source.parent, guide.getKeyName(source), function (src)
+ nodeMgr.setNode(source, m.compileNode(src))
+ end)
end)
: case 'field'
: case 'method'
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 7d3e3280..bdd13760 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -17,7 +17,8 @@ local function getSource(pos)
or source.type == 'getglobal'
or source.type == 'field'
or source.type == 'method'
- or source.type == 'function' then
+ or source.type == 'function'
+ or source.type == 'table' then
result = source
end
end)
@@ -1210,14 +1211,14 @@ local x
<?x?> = 1
]]
-TEST 'any' [[
+TEST 'unknown' [[
---@return number
local function f(x)
local <?y?> = x()
end
]]
-TEST 'any' [[
+TEST 'unknown' [[
local mt
---@return number
@@ -1226,7 +1227,7 @@ function mt:f() end
local <?v?> = mt()
]]
-TEST 'any' [[
+TEST 'unknown' [[
local <?mt?>
---@class X
@@ -1240,7 +1241,7 @@ local mt
function mt:f(<?x?>) end
]]
-TEST 'any' [[
+TEST 'unknown' [[
local <?mt?>
---@type number
@@ -1270,6 +1271,22 @@ mt:loop(function (<?i?>)
end)
]]
+TEST 'C' [[
+---@class D
+---@field y integer # D comment
+
+---@class C
+---@field x integer # C comment
+---@field d D
+
+---@param c C
+local function f(c) end
+
+f <?{?>
+ x = ,
+}
+]]
+
TEST 'integer' [[
---@class D
---@field y integer # D comment