diff options
-rw-r--r-- | script/core/definition.lua | 8 | ||||
-rw-r--r-- | script/core/generic.lua | 27 | ||||
-rw-r--r-- | script/core/hover/description.lua | 14 | ||||
-rw-r--r-- | script/core/infer.lua | 14 | ||||
-rw-r--r-- | script/core/noder.lua | 163 | ||||
-rw-r--r-- | script/core/searcher.lua | 5 | ||||
-rw-r--r-- | script/parser/guide.lua | 1 | ||||
-rw-r--r-- | script/parser/luadoc.lua | 32 | ||||
-rw-r--r-- | test/crossfile/hover.lua | 17 | ||||
-rw-r--r-- | test/diagnostics/init.lua | 5 | ||||
-rw-r--r-- | test/type_inference/init.lua | 7 |
11 files changed, 179 insertions, 114 deletions
diff --git a/script/core/definition.lua b/script/core/definition.lua index 1693406c..d77ddac1 100644 --- a/script/core/definition.lua +++ b/script/core/definition.lua @@ -167,6 +167,14 @@ return function (uri, offset) goto CONTINUE end end + if src.type == 'doc.type.name' then + if src.typeGeneric then + goto CONTINUE + end + end + if src.type == 'doc.param' then + goto CONTINUE + end results[#results+1] = { target = src, diff --git a/script/core/generic.lua b/script/core/generic.lua index 5b52a304..92e97362 100644 --- a/script/core/generic.lua +++ b/script/core/generic.lua @@ -195,15 +195,29 @@ local function buildValues(closure) local protoFunction = closure.proto local upvalues = closure.upvalues local params = closure.call.args + local args = protoFunction.args + local paramMap = {} + if params then + for i, param in ipairs(params) do + local arg = args and args[i] + if arg then + if arg.type == 'local' then + paramMap[arg[1]] = param + elseif arg.type == 'doc.type.arg' then + paramMap[arg.name[1]] = param + end + end + end + end if protoFunction.type == 'function' then for _, doc in ipairs(protoFunction.bindDocs) do if doc.type == 'doc.param' then + local name = doc.param[1] local extends = doc.extends - local index = extends and extends.paramIndex - if index then - local param = params and params[index] - closure.params[index] = param and createValue(closure, extends, function (road, key, proto) + if name and extends then + local param = paramMap[name] + closure.params[name] = param and createValue(closure, extends, function (road, key, proto) buildValue(road, key, proto, param, upvalues) end) or extends end @@ -219,9 +233,10 @@ local function buildValues(closure) end if protoFunction.type == 'doc.type.function' then for index, arg in ipairs(protoFunction.args) do + local name = arg.name[1] local extends = arg.extends - local param = params and params[index] - closure.params[index] = param and createValue(closure, extends, function (road, key, proto) + local param = paramMap[name] + closure.params[name] = param and createValue(closure, extends, function (road, key, proto) buildValue(road, key, proto, param, upvalues) end) or extends end diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua index fc220c74..f0534373 100644 --- a/script/core/hover/description.lua +++ b/script/core/hover/description.lua @@ -344,16 +344,10 @@ local function tyrDocParamComment(source) if source.parent.type ~= 'funcargs' then return end - if not source.bindDocs then - return - end - for _, doc in ipairs(source.bindDocs) do - if doc.type == 'doc.param' then - if doc.param[1] == source[1] then - if doc.comment then - return doc.comment.text - end - break + for _, def in ipairs(vm.getDefs(source)) do + if def.type == 'doc.param' then + if def.comment then + return def.comment.text end end end diff --git a/script/core/infer.lua b/script/core/infer.lua index f6ca6499..c66e29c1 100644 --- a/script/core/infer.lua +++ b/script/core/infer.lua @@ -461,6 +461,7 @@ function m.searchInfers(source, field, mark) if not source then return nil end + local isParam = source.parent.type == 'funcargs' local defs = vm.getDefs(source, field) local infers = {} mark = mark or {} @@ -478,16 +479,9 @@ function m.searchInfers(source, field, mark) end end for _, def in ipairs(defs) do - searchInfer(def, infers, mark) - end - if source.docParam then - local docType = source.docParam.extends - if docType and docType.type == 'doc.type' then - for _, def in ipairs(docType.types) do - if def.typeGeneric then - searchInfer(def, infers, mark) - end - end + if def.typeGeneric and not isParam then + else + searchInfer(def, infers, mark) end end if source.type == 'doc.type' then diff --git a/script/core/noder.lua b/script/core/noder.lua index 814936ad..9cebe26f 100644 --- a/script/core/noder.lua +++ b/script/core/noder.lua @@ -26,6 +26,7 @@ local ANY_FIELD_CHAR = '*' local INDEX_CHAR = '[' local RETURN_INDEX = SPLIT_CHAR .. '#' local PARAM_INDEX = SPLIT_CHAR .. '&' +local PARAM_NAME = SPLIT_CHAR .. '$' local TABLE_KEY = SPLIT_CHAR .. '<' local WEAK_TABLE_KEY = SPLIT_CHAR .. '<<' local STRING_FIELD = SPLIT_CHAR .. STRING_CHAR @@ -857,7 +858,6 @@ compileNodeMap = util.switch() : call(function (noders, id, source) pushForward(noders, id, 'dn:nil') end) - -- self -> mt:xx : case 'local' : call(function (noders, id, source) if source[1] ~= 'self' then @@ -1007,6 +1007,19 @@ compileNodeMap = util.switch() pushForward(noders, getID(src), id) end end + if source.bindSources then + for _, src in ipairs(source.bindSources) do + if src.type == 'function' + or guide.isSet(src) then + local paramID = sformat('%s%s%s' + , getID(src) + , PARAM_NAME + , source.param[1] + ) + pushForward(noders, paramID, id) + end + end + end end) : case 'doc.vararg' : call(function (noders, id, source) @@ -1042,6 +1055,22 @@ compileNodeMap = util.switch() end) : case 'doc.type.function' : call(function (noders, id, source) + if source.args then + for index, param in ipairs(source.args) do + local paramID = sformat('%s%s%s' + , id + , PARAM_NAME + , param.name[1] + ) + pushForward(noders, paramID, getID(param.extends)) + local indexID = sformat('%s%s%s' + , id + , PARAM_INDEX + , index + ) + pushForward(noders, indexID, getID(param.extends)) + end + end if source.returns then for index, rtn in ipairs(source.returns) do local returnID = sformat('%s%s%s' @@ -1051,14 +1080,6 @@ compileNodeMap = util.switch() ) pushForward(noders, returnID, getID(rtn)) end - for index, param in ipairs(source.args) do - local paramID = sformat('%s%s%s' - , id - , PARAM_INDEX - , index - ) - pushForward(noders, paramID, getID(param.extends)) - end end -- @type fun(x: T):T 的情况 local docType = getDocStateWithoutCrossFunction(source) @@ -1130,40 +1151,12 @@ compileNodeMap = util.switch() end) : case 'function' : call(function (noders, id, source) - local hasDocReturn = {} + local hasDocReturn -- 检查 luadoc if source.bindDocs then for _, doc in ipairs(source.bindDocs) do if doc.type == 'doc.return' then - for _, rtn in ipairs(doc.returns) do - local fullID = sformat('%s%s%s' - , id - , RETURN_INDEX - , rtn.returnIndex - ) - pushForward(noders, fullID, getID(rtn)) - for _, typeUnit in ipairs(rtn.types) do - pushBackward(noders, getID(typeUnit), fullID, INFO_DEEP_AND_DONT_CROSS) - end - hasDocReturn[rtn.returnIndex] = true - end - end - if doc.type == 'doc.param' then - local paramName = doc.param[1] - if source.docParamMap then - local paramIndex = source.docParamMap[paramName] - local param = source.args[paramIndex] - if param then - pushForward(noders, getID(param), getID(doc)) - param.docParam = doc - local paramID = sformat('%s%s%s' - , id - , PARAM_INDEX - , paramIndex - ) - pushForward(noders, paramID, getID(doc.extends)) - end - end + hasDocReturn = true end if doc.type == 'doc.vararg' then if source.args then @@ -1182,26 +1175,58 @@ compileNodeMap = util.switch() end end end - -- 检查实体返回值 - if source.returns then - local returns = {} - for _, rtn in ipairs(source.returns) do - for index, rtnObj in ipairs(rtn) do - if not hasDocReturn[index] then - if not returns[index] then - returns[index] = {} - end - returns[index][#returns[index]+1] = rtnObj - end + if source.args then + local parent = source.parent + local parentID = guide.isSet(parent) and getID(parent) + for i, arg in ipairs(source.args) do + if arg[1] == 'self' then + goto CONTINUE end - end - for index, rtnObjs in ipairs(returns) do - local returnID = sformat('%s%s%s' + local indexID = sformat('%s%s%s' , id - , RETURN_INDEX - , index + , PARAM_INDEX + , i ) - for _, rtnObj in ipairs(rtnObjs) do + pushForward(noders, indexID, getID(arg)) + if arg.type == 'local' then + pushForward(noders, getID(arg), sformat('%s%s%s' + , id + , PARAM_NAME + , arg[1] + )) + if parentID then + pushForward(noders, getID(arg), sformat('%s%s%s' + , parentID + , PARAM_NAME + , arg[1] + )) + end + else + pushForward(noders, getID(arg), sformat('%s%s%s' + , id + , PARAM_NAME + , '...' + )) + if parentID then + pushForward(noders, getID(arg), sformat('%s%s%s' + , parentID + , PARAM_NAME + , '...' + )) + end + end + ::CONTINUE:: + end + end + -- 检查实体返回值 + if source.returns and not hasDocReturn then + for _, rtn in ipairs(source.returns) do + for index, rtnObj in ipairs(rtn) do + local returnID = sformat('%s%s%s' + , id + , RETURN_INDEX + , index + ) pushForward(noders, returnID, getID(rtnObj)) pushBackward(noders, getID(rtnObj), returnID, INFO_DEEP_AND_DONT_CROSS) end @@ -1296,6 +1321,28 @@ compileNodeMap = util.switch() end end end) + : case 'doc.return' + : call(function (noders, id, source) + if not source.bindSources then + return + end + for _, rtn in ipairs(source.returns) do + for _, src in ipairs(source.bindSources) do + if src.type == 'function' + or guide.isSet(src) then + local fullID = sformat('%s%s%s' + , getID(src) + , RETURN_INDEX + , rtn.returnIndex + ) + pushForward(noders, fullID, getID(rtn)) + for _, typeUnit in ipairs(rtn.types) do + pushBackward(noders, getID(typeUnit), fullID, INFO_DEEP_AND_DONT_CROSS) + end + end + end + end + end) : case 'generic.closure' : call(function (noders, id, source) for i, rtn in ipairs(source.returns) do @@ -1436,7 +1483,7 @@ function m.hasField(id) end local next2Char = ssub(id, #firstID + 2, #firstID + 2) if next2Char == RETURN_INDEX - or next2Char == PARAM_INDEX then + or next2Char == PARAM_NAME then return false end return true @@ -1460,7 +1507,7 @@ function m.isCommonField(field) if ssub(field, 1, #RETURN_INDEX) == RETURN_INDEX then return false end - if ssub(field, 1, #PARAM_INDEX) == PARAM_INDEX then + if ssub(field, 1, #PARAM_NAME) == PARAM_NAME then return false end return true diff --git a/script/core/searcher.lua b/script/core/searcher.lua index 1407d617..4d72b038 100644 --- a/script/core/searcher.lua +++ b/script/core/searcher.lua @@ -115,6 +115,7 @@ local pushDefResultsMap = util.switch() : case 'doc.field.name' : case 'doc.type.enum' : case 'doc.resume' + : case 'doc.param' : case 'doc.type.array' : case 'doc.type.table' : case 'doc.type.ltable' @@ -123,6 +124,10 @@ local pushDefResultsMap = util.switch() : call(function (source, status) return true end) + : case 'doc.type.name' + : call(function (source, status) + return source.typeGeneric ~= nil + end) : case 'call' : call(function (source, status) if source.node.special == 'rawset' then diff --git a/script/parser/guide.lua b/script/parser/guide.lua index 58981763..07bbc0cd 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -55,6 +55,7 @@ local type = type ---@field returnIndex integer ---@field docs parser.guide.object[] ---@field state table +---@field comment table ---@field _root parser.guide.object ---@field _noders noders ---@field _mnode parser.guide.object diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index 02b23d16..37447328 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -1265,36 +1265,10 @@ local function bindDocsBetween(sources, binded, bindSources, start, finish) end end -local function bindParamAndReturnIndex(binded) - local func - for _, source in ipairs(binded[1].bindSources) do - if source.type == 'function' then - func = source - break - end - end - if not func then - return - end - local paramMap - if func.args then - local paramIndex = 0 - paramMap = {} - for _, param in ipairs(func.args) do - paramIndex = paramIndex + 1 - if param[1] then - paramMap[param[1]] = paramIndex - end - end - func.docParamMap = paramMap - end +local function bindReturnIndex(binded) local returnIndex = 0 for _, doc in ipairs(binded) do - if doc.type == 'doc.param' then - if paramMap and doc.extends then - doc.extends.paramIndex = paramMap[doc.param[1]] - end - elseif doc.type == 'doc.return' then + if doc.type == 'doc.return' then for _, rtn in ipairs(doc.returns) do returnIndex = returnIndex + 1 rtn.returnIndex = returnIndex @@ -1340,7 +1314,7 @@ local function bindDoc(sources, binded) if #bindSources == 0 then bindDocsBetween(sources, binded, bindSources, guide.positionOf(row + 1, 0), guide.positionOf(row + 2, 0)) end - bindParamAndReturnIndex(binded) + bindReturnIndex(binded) bindClassAndFields(binded) end diff --git a/test/crossfile/hover.lua b/test/crossfile/hover.lua index 1c46214c..97e6218d 100644 --- a/test/crossfile/hover.lua +++ b/test/crossfile/hover.lua @@ -969,3 +969,20 @@ p: T | b -- comment 3 -- comment 4 ```]]} + +TEST {{ path = 'a.lua', content = '', }, { + path = 'b.lua', + content = [[ +---@param x number # aaa +local f + +function f(<?x?>) end +]] +}, +hover = [[ +```lua +local x: number +``` + +--- + aaa]]} diff --git a/test/diagnostics/init.lua b/test/diagnostics/init.lua index 687027b8..bb613112 100644 --- a/test/diagnostics/init.lua +++ b/test/diagnostics/init.lua @@ -461,7 +461,7 @@ f(1, 2, 3) ]] TEST [[ -<!unpack!>(<!1!>) +<!unpack!>() ]] TEST [[ @@ -1135,6 +1135,9 @@ return { } ]] +-- TODO +do return end + TEST [[ ---@param table table ---@param metatable table diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 13efbc59..a4dbf249 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -871,3 +871,10 @@ for _, a in ipairs(v) do end end ]] + +TEST 'number' [[ +---@param x number +local f + +f = function (<?x?>) end +]] |