summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/definition.lua8
-rw-r--r--script/core/generic.lua27
-rw-r--r--script/core/hover/description.lua14
-rw-r--r--script/core/infer.lua14
-rw-r--r--script/core/noder.lua163
-rw-r--r--script/core/searcher.lua5
-rw-r--r--script/parser/guide.lua1
-rw-r--r--script/parser/luadoc.lua32
-rw-r--r--test/crossfile/hover.lua17
-rw-r--r--test/diagnostics/init.lua5
-rw-r--r--test/type_inference/init.lua7
11 files changed, 179 insertions, 114 deletions
diff --git a/script/core/definition.lua b/script/core/definition.lua
index 1693406c..d77ddac1 100644
--- a/script/core/definition.lua
+++ b/script/core/definition.lua
@@ -167,6 +167,14 @@ return function (uri, offset)
goto CONTINUE
end
end
+ if src.type == 'doc.type.name' then
+ if src.typeGeneric then
+ goto CONTINUE
+ end
+ end
+ if src.type == 'doc.param' then
+ goto CONTINUE
+ end
results[#results+1] = {
target = src,
diff --git a/script/core/generic.lua b/script/core/generic.lua
index 5b52a304..92e97362 100644
--- a/script/core/generic.lua
+++ b/script/core/generic.lua
@@ -195,15 +195,29 @@ local function buildValues(closure)
local protoFunction = closure.proto
local upvalues = closure.upvalues
local params = closure.call.args
+ local args = protoFunction.args
+ local paramMap = {}
+ if params then
+ for i, param in ipairs(params) do
+ local arg = args and args[i]
+ if arg then
+ if arg.type == 'local' then
+ paramMap[arg[1]] = param
+ elseif arg.type == 'doc.type.arg' then
+ paramMap[arg.name[1]] = param
+ end
+ end
+ end
+ end
if protoFunction.type == 'function' then
for _, doc in ipairs(protoFunction.bindDocs) do
if doc.type == 'doc.param' then
+ local name = doc.param[1]
local extends = doc.extends
- local index = extends and extends.paramIndex
- if index then
- local param = params and params[index]
- closure.params[index] = param and createValue(closure, extends, function (road, key, proto)
+ if name and extends then
+ local param = paramMap[name]
+ closure.params[name] = param and createValue(closure, extends, function (road, key, proto)
buildValue(road, key, proto, param, upvalues)
end) or extends
end
@@ -219,9 +233,10 @@ local function buildValues(closure)
end
if protoFunction.type == 'doc.type.function' then
for index, arg in ipairs(protoFunction.args) do
+ local name = arg.name[1]
local extends = arg.extends
- local param = params and params[index]
- closure.params[index] = param and createValue(closure, extends, function (road, key, proto)
+ local param = paramMap[name]
+ closure.params[name] = param and createValue(closure, extends, function (road, key, proto)
buildValue(road, key, proto, param, upvalues)
end) or extends
end
diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua
index fc220c74..f0534373 100644
--- a/script/core/hover/description.lua
+++ b/script/core/hover/description.lua
@@ -344,16 +344,10 @@ local function tyrDocParamComment(source)
if source.parent.type ~= 'funcargs' then
return
end
- if not source.bindDocs then
- return
- end
- for _, doc in ipairs(source.bindDocs) do
- if doc.type == 'doc.param' then
- if doc.param[1] == source[1] then
- if doc.comment then
- return doc.comment.text
- end
- break
+ for _, def in ipairs(vm.getDefs(source)) do
+ if def.type == 'doc.param' then
+ if def.comment then
+ return def.comment.text
end
end
end
diff --git a/script/core/infer.lua b/script/core/infer.lua
index f6ca6499..c66e29c1 100644
--- a/script/core/infer.lua
+++ b/script/core/infer.lua
@@ -461,6 +461,7 @@ function m.searchInfers(source, field, mark)
if not source then
return nil
end
+ local isParam = source.parent.type == 'funcargs'
local defs = vm.getDefs(source, field)
local infers = {}
mark = mark or {}
@@ -478,16 +479,9 @@ function m.searchInfers(source, field, mark)
end
end
for _, def in ipairs(defs) do
- searchInfer(def, infers, mark)
- end
- if source.docParam then
- local docType = source.docParam.extends
- if docType and docType.type == 'doc.type' then
- for _, def in ipairs(docType.types) do
- if def.typeGeneric then
- searchInfer(def, infers, mark)
- end
- end
+ if def.typeGeneric and not isParam then
+ else
+ searchInfer(def, infers, mark)
end
end
if source.type == 'doc.type' then
diff --git a/script/core/noder.lua b/script/core/noder.lua
index 814936ad..9cebe26f 100644
--- a/script/core/noder.lua
+++ b/script/core/noder.lua
@@ -26,6 +26,7 @@ local ANY_FIELD_CHAR = '*'
local INDEX_CHAR = '['
local RETURN_INDEX = SPLIT_CHAR .. '#'
local PARAM_INDEX = SPLIT_CHAR .. '&'
+local PARAM_NAME = SPLIT_CHAR .. '$'
local TABLE_KEY = SPLIT_CHAR .. '<'
local WEAK_TABLE_KEY = SPLIT_CHAR .. '<<'
local STRING_FIELD = SPLIT_CHAR .. STRING_CHAR
@@ -857,7 +858,6 @@ compileNodeMap = util.switch()
: call(function (noders, id, source)
pushForward(noders, id, 'dn:nil')
end)
- -- self -> mt:xx
: case 'local'
: call(function (noders, id, source)
if source[1] ~= 'self' then
@@ -1007,6 +1007,19 @@ compileNodeMap = util.switch()
pushForward(noders, getID(src), id)
end
end
+ if source.bindSources then
+ for _, src in ipairs(source.bindSources) do
+ if src.type == 'function'
+ or guide.isSet(src) then
+ local paramID = sformat('%s%s%s'
+ , getID(src)
+ , PARAM_NAME
+ , source.param[1]
+ )
+ pushForward(noders, paramID, id)
+ end
+ end
+ end
end)
: case 'doc.vararg'
: call(function (noders, id, source)
@@ -1042,6 +1055,22 @@ compileNodeMap = util.switch()
end)
: case 'doc.type.function'
: call(function (noders, id, source)
+ if source.args then
+ for index, param in ipairs(source.args) do
+ local paramID = sformat('%s%s%s'
+ , id
+ , PARAM_NAME
+ , param.name[1]
+ )
+ pushForward(noders, paramID, getID(param.extends))
+ local indexID = sformat('%s%s%s'
+ , id
+ , PARAM_INDEX
+ , index
+ )
+ pushForward(noders, indexID, getID(param.extends))
+ end
+ end
if source.returns then
for index, rtn in ipairs(source.returns) do
local returnID = sformat('%s%s%s'
@@ -1051,14 +1080,6 @@ compileNodeMap = util.switch()
)
pushForward(noders, returnID, getID(rtn))
end
- for index, param in ipairs(source.args) do
- local paramID = sformat('%s%s%s'
- , id
- , PARAM_INDEX
- , index
- )
- pushForward(noders, paramID, getID(param.extends))
- end
end
-- @type fun(x: T):T 的情况
local docType = getDocStateWithoutCrossFunction(source)
@@ -1130,40 +1151,12 @@ compileNodeMap = util.switch()
end)
: case 'function'
: call(function (noders, id, source)
- local hasDocReturn = {}
+ local hasDocReturn
-- 检查 luadoc
if source.bindDocs then
for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.return' then
- for _, rtn in ipairs(doc.returns) do
- local fullID = sformat('%s%s%s'
- , id
- , RETURN_INDEX
- , rtn.returnIndex
- )
- pushForward(noders, fullID, getID(rtn))
- for _, typeUnit in ipairs(rtn.types) do
- pushBackward(noders, getID(typeUnit), fullID, INFO_DEEP_AND_DONT_CROSS)
- end
- hasDocReturn[rtn.returnIndex] = true
- end
- end
- if doc.type == 'doc.param' then
- local paramName = doc.param[1]
- if source.docParamMap then
- local paramIndex = source.docParamMap[paramName]
- local param = source.args[paramIndex]
- if param then
- pushForward(noders, getID(param), getID(doc))
- param.docParam = doc
- local paramID = sformat('%s%s%s'
- , id
- , PARAM_INDEX
- , paramIndex
- )
- pushForward(noders, paramID, getID(doc.extends))
- end
- end
+ hasDocReturn = true
end
if doc.type == 'doc.vararg' then
if source.args then
@@ -1182,26 +1175,58 @@ compileNodeMap = util.switch()
end
end
end
- -- 检查实体返回值
- if source.returns then
- local returns = {}
- for _, rtn in ipairs(source.returns) do
- for index, rtnObj in ipairs(rtn) do
- if not hasDocReturn[index] then
- if not returns[index] then
- returns[index] = {}
- end
- returns[index][#returns[index]+1] = rtnObj
- end
+ if source.args then
+ local parent = source.parent
+ local parentID = guide.isSet(parent) and getID(parent)
+ for i, arg in ipairs(source.args) do
+ if arg[1] == 'self' then
+ goto CONTINUE
end
- end
- for index, rtnObjs in ipairs(returns) do
- local returnID = sformat('%s%s%s'
+ local indexID = sformat('%s%s%s'
, id
- , RETURN_INDEX
- , index
+ , PARAM_INDEX
+ , i
)
- for _, rtnObj in ipairs(rtnObjs) do
+ pushForward(noders, indexID, getID(arg))
+ if arg.type == 'local' then
+ pushForward(noders, getID(arg), sformat('%s%s%s'
+ , id
+ , PARAM_NAME
+ , arg[1]
+ ))
+ if parentID then
+ pushForward(noders, getID(arg), sformat('%s%s%s'
+ , parentID
+ , PARAM_NAME
+ , arg[1]
+ ))
+ end
+ else
+ pushForward(noders, getID(arg), sformat('%s%s%s'
+ , id
+ , PARAM_NAME
+ , '...'
+ ))
+ if parentID then
+ pushForward(noders, getID(arg), sformat('%s%s%s'
+ , parentID
+ , PARAM_NAME
+ , '...'
+ ))
+ end
+ end
+ ::CONTINUE::
+ end
+ end
+ -- 检查实体返回值
+ if source.returns and not hasDocReturn then
+ for _, rtn in ipairs(source.returns) do
+ for index, rtnObj in ipairs(rtn) do
+ local returnID = sformat('%s%s%s'
+ , id
+ , RETURN_INDEX
+ , index
+ )
pushForward(noders, returnID, getID(rtnObj))
pushBackward(noders, getID(rtnObj), returnID, INFO_DEEP_AND_DONT_CROSS)
end
@@ -1296,6 +1321,28 @@ compileNodeMap = util.switch()
end
end
end)
+ : case 'doc.return'
+ : call(function (noders, id, source)
+ if not source.bindSources then
+ return
+ end
+ for _, rtn in ipairs(source.returns) do
+ for _, src in ipairs(source.bindSources) do
+ if src.type == 'function'
+ or guide.isSet(src) then
+ local fullID = sformat('%s%s%s'
+ , getID(src)
+ , RETURN_INDEX
+ , rtn.returnIndex
+ )
+ pushForward(noders, fullID, getID(rtn))
+ for _, typeUnit in ipairs(rtn.types) do
+ pushBackward(noders, getID(typeUnit), fullID, INFO_DEEP_AND_DONT_CROSS)
+ end
+ end
+ end
+ end
+ end)
: case 'generic.closure'
: call(function (noders, id, source)
for i, rtn in ipairs(source.returns) do
@@ -1436,7 +1483,7 @@ function m.hasField(id)
end
local next2Char = ssub(id, #firstID + 2, #firstID + 2)
if next2Char == RETURN_INDEX
- or next2Char == PARAM_INDEX then
+ or next2Char == PARAM_NAME then
return false
end
return true
@@ -1460,7 +1507,7 @@ function m.isCommonField(field)
if ssub(field, 1, #RETURN_INDEX) == RETURN_INDEX then
return false
end
- if ssub(field, 1, #PARAM_INDEX) == PARAM_INDEX then
+ if ssub(field, 1, #PARAM_NAME) == PARAM_NAME then
return false
end
return true
diff --git a/script/core/searcher.lua b/script/core/searcher.lua
index 1407d617..4d72b038 100644
--- a/script/core/searcher.lua
+++ b/script/core/searcher.lua
@@ -115,6 +115,7 @@ local pushDefResultsMap = util.switch()
: case 'doc.field.name'
: case 'doc.type.enum'
: case 'doc.resume'
+ : case 'doc.param'
: case 'doc.type.array'
: case 'doc.type.table'
: case 'doc.type.ltable'
@@ -123,6 +124,10 @@ local pushDefResultsMap = util.switch()
: call(function (source, status)
return true
end)
+ : case 'doc.type.name'
+ : call(function (source, status)
+ return source.typeGeneric ~= nil
+ end)
: case 'call'
: call(function (source, status)
if source.node.special == 'rawset' then
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index 58981763..07bbc0cd 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -55,6 +55,7 @@ local type = type
---@field returnIndex integer
---@field docs parser.guide.object[]
---@field state table
+---@field comment table
---@field _root parser.guide.object
---@field _noders noders
---@field _mnode parser.guide.object
diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua
index 02b23d16..37447328 100644
--- a/script/parser/luadoc.lua
+++ b/script/parser/luadoc.lua
@@ -1265,36 +1265,10 @@ local function bindDocsBetween(sources, binded, bindSources, start, finish)
end
end
-local function bindParamAndReturnIndex(binded)
- local func
- for _, source in ipairs(binded[1].bindSources) do
- if source.type == 'function' then
- func = source
- break
- end
- end
- if not func then
- return
- end
- local paramMap
- if func.args then
- local paramIndex = 0
- paramMap = {}
- for _, param in ipairs(func.args) do
- paramIndex = paramIndex + 1
- if param[1] then
- paramMap[param[1]] = paramIndex
- end
- end
- func.docParamMap = paramMap
- end
+local function bindReturnIndex(binded)
local returnIndex = 0
for _, doc in ipairs(binded) do
- if doc.type == 'doc.param' then
- if paramMap and doc.extends then
- doc.extends.paramIndex = paramMap[doc.param[1]]
- end
- elseif doc.type == 'doc.return' then
+ if doc.type == 'doc.return' then
for _, rtn in ipairs(doc.returns) do
returnIndex = returnIndex + 1
rtn.returnIndex = returnIndex
@@ -1340,7 +1314,7 @@ local function bindDoc(sources, binded)
if #bindSources == 0 then
bindDocsBetween(sources, binded, bindSources, guide.positionOf(row + 1, 0), guide.positionOf(row + 2, 0))
end
- bindParamAndReturnIndex(binded)
+ bindReturnIndex(binded)
bindClassAndFields(binded)
end
diff --git a/test/crossfile/hover.lua b/test/crossfile/hover.lua
index 1c46214c..97e6218d 100644
--- a/test/crossfile/hover.lua
+++ b/test/crossfile/hover.lua
@@ -969,3 +969,20 @@ p: T
| b -- comment 3
-- comment 4
```]]}
+
+TEST {{ path = 'a.lua', content = '', }, {
+ path = 'b.lua',
+ content = [[
+---@param x number # aaa
+local f
+
+function f(<?x?>) end
+]]
+},
+hover = [[
+```lua
+local x: number
+```
+
+---
+ aaa]]}
diff --git a/test/diagnostics/init.lua b/test/diagnostics/init.lua
index 687027b8..bb613112 100644
--- a/test/diagnostics/init.lua
+++ b/test/diagnostics/init.lua
@@ -461,7 +461,7 @@ f(1, 2, 3)
]]
TEST [[
-<!unpack!>(<!1!>)
+<!unpack!>()
]]
TEST [[
@@ -1135,6 +1135,9 @@ return {
}
]]
+-- TODO
+do return end
+
TEST [[
---@param table table
---@param metatable table
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 13efbc59..a4dbf249 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -871,3 +871,10 @@ for _, a in ipairs(v) do
end
end
]]
+
+TEST 'number' [[
+---@param x number
+local f
+
+f = function (<?x?>) end
+]]