diff options
-rw-r--r-- | script/core/completion/completion.lua | 38 | ||||
-rw-r--r-- | script/core/hover/label.lua | 9 | ||||
-rw-r--r-- | script/vm/compiler.lua | 106 | ||||
-rw-r--r-- | script/vm/infer.lua | 7 | ||||
-rw-r--r-- | test/completion/common.lua | 104 | ||||
-rw-r--r-- | test/type_inference/init.lua | 14 |
6 files changed, 170 insertions, 108 deletions
diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua index b6fd15ee..9454d6c2 100644 --- a/script/core/completion/completion.lua +++ b/script/core/completion/completion.lua @@ -311,7 +311,7 @@ local function checkLocal(state, word, position, results) if name:sub(1, 1) == '@' then goto CONTINUE end - if infer.getInfer(source):hasType 'function' then + if infer.getInfer(source):hasFunction() then for _, def in ipairs(vm.getDefs(source)) do if def.type == 'function' or def.type == 'doc.type.function' then @@ -419,7 +419,7 @@ local function checkModule(state, word, position, results) end local function checkFieldFromFieldToIndex(state, name, src, parent, word, startPos, position) - if name:match '^[%a_][%w_]*$' then + if type(name) == 'string' and name:match '^[%a_][%w_]*$' then return nil end local textEdit, additionalTextEdits @@ -492,7 +492,7 @@ local function checkFieldThen(state, name, src, word, startPos, position, parent kind = define.CompletionItemKind.Function end buildFunction(results, src, value, oop, { - label = name, + label = tostring(name), kind = kind, match = name:match '^[^(]+', insertText = name:match '^[^(]+', @@ -506,7 +506,7 @@ local function checkFieldThen(state, name, src, word, startPos, position, parent }) return end - if oop and not infer.getInfer(src):hasType 'function' then + if oop and not infer.getInfer(src):hasFunction() then return end local literal = guide.getLiteral(value) @@ -525,7 +525,7 @@ local function checkFieldThen(state, name, src, word, startPos, position, parent textEdit, additionalTextEdits = checkFieldFromFieldToIndex(state, name, src, parent, word, startPos, position) end results[#results+1] = { - label = name, + label = tostring(name), kind = kind, deprecated = vm.isDeprecated(src) or nil, textEdit = textEdit, @@ -1132,7 +1132,7 @@ local function checkTypingEnum(state, position, defs, str, results) if def.type == 'doc.type.string' or def.type == 'doc.type.integer' then enums[#enums+1] = { - label = def[1], + label = util.viewLiteral(def[1]), description = def.comment and def.comment.text, kind = define.CompletionItemKind.EnumMember, } @@ -1340,7 +1340,8 @@ local function pushCallEnumsAndFuncs(source) local defs = vm.getDefs(source) local results = {} for _, def in ipairs(defs) do - if def.type == 'doc.type.string' then + if def.type == 'doc.type.string' + or def.type == 'doc.type.integer' then results[#results+1] = { label = util.viewLiteral(def[1]), description = def.comment, @@ -1391,20 +1392,16 @@ local function getCallEnumsAndFuncs(source, index, oop, call) return pushCallEnumsAndFuncs(arg.extends) end end - if source.type == 'doc.field.name' then + if source.type == 'doc.field' then local currentIndex = index if oop then currentIndex = index - 1 end - local class = source.parent.class - if not class then - return - end local results = {} local valueBeforeIndex = index > 1 and call.args[index - 1][1] for _, doc in ipairs(class.fields) do - if doc.field ~= source + if doc ~= source and doc.field[1] == source[1] then local indexType = currentIndex if not oop then @@ -1479,7 +1476,7 @@ local function checkTableLiteralField(state, position, tbl, fields, results) end end table.sort(fields, function (a, b) - return guide.getKeyName(a) < guide.getKeyName(b) + return tostring(guide.getKeyName(a)) < tostring(guide.getKeyName(b)) end) -- {$} local left = lookBackward.findWord(text, guide.positionToOffset(state, position)) @@ -1519,6 +1516,9 @@ local function tryCallArg(state, position, results) if arg and arg.type == 'function' then return end + if not arg then + arg = { type = 'dummy' } + end local defs = vm.getDefs(call.node) for _, def in ipairs(defs) do def = vm.getObjectValue(def) or def @@ -1774,9 +1774,15 @@ local function tryluaDocByErr(state, position, err, docState, results) end elseif err.type == 'LUADOC_MISS_TYPE_NAME' then for _, doc in ipairs(vm.getDocSets(state.uri)) do - if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name') then + if doc.type == 'doc.class' then + results[#results+1] = { + label = doc.class[1], + kind = define.CompletionItemKind.Class, + } + end + if doc.type == 'doc.alias' then results[#results+1] = { - label = doc[1], + label = doc.alias[1], kind = define.CompletionItemKind.Class, } end diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua index 7f9f29a4..995c3294 100644 --- a/script/core/hover/label.lua +++ b/script/core/hover/label.lua @@ -119,16 +119,15 @@ local function asField(source) end local function asDocFieldName(source) - local name = source[1] - local docField = source.parent + local name = source.field[1] local class - for _, doc in ipairs(docField.bindGroup) do + for _, doc in ipairs(source.bindGroup) do if doc.type == 'doc.class' then class = doc break end end - local view = infer.getInfer(docField.extends):view() + local view = infer.getInfer(source.extends):view() if not class then return ('field ?.%s: %s'):format(name, view) end @@ -202,7 +201,7 @@ return function (source, oop) return asNumber(source) elseif source.type == 'doc.type.name' then return asDocTypeName(source) - elseif source.type == 'doc.field.name' then + elseif source.type == 'doc.field' then return asDocFieldName(source) end end diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 61a6076f..d5d65dfb 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -483,16 +483,38 @@ local function isValidCallArgNode(source, node) return node.type == 'doc.type.table' or (node.type == 'global' and node.cate == 'type' and not guide.isBasicType(node.name)) end + if source.type == 'dummy' then + return true + end return false end -local function setCallArgNode(source, call, callNode, fixIndex) +---@param func parser.object +---@param index integer +---@return parser.object? +local function getFuncArg(func, index) + local args = func.args + if not args then + return nil + end + if args[index] then + return args[index] + end + local lastArg = args[#args] + if lastArg.type == '...' then + return lastArg + end + return nil +end + +local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex) local valueMgr = require 'vm.value' - local myIndex - for i, arg in ipairs(call.args) do - if arg == source then - myIndex = i - fixIndex - break + if not myIndex then + for i, carg in ipairs(call.args) do + if carg == arg then + myIndex = i - fixIndex + break + end end end @@ -508,10 +530,10 @@ local function setCallArgNode(source, call, callNode, fixIndex) for n in nodeMgr.eachNode(callNode) do if n.type == 'function' then - local arg = n.args[myIndex] - for fn in nodeMgr.eachNode(m.compileNode(arg)) do - if isValidCallArgNode(source, fn) then - nodeMgr.setNode(source, fn) + local farg = getFuncArg(n, myIndex) + for fn in nodeMgr.eachNode(m.compileNode(farg)) do + if isValidCallArgNode(arg, fn) then + nodeMgr.setNode(arg, fn) end end end @@ -521,10 +543,10 @@ local function setCallArgNode(source, call, callNode, fixIndex) or not eventMap or event.type ~= 'doc.type.string' or eventMap[event[1]] then - local arg = n.args[myIndex] - for fn in nodeMgr.eachNode(m.compileNode(arg)) do - if isValidCallArgNode(source, fn) then - nodeMgr.setNode(source, fn) + local farg = getFuncArg(n, myIndex) + for fn in nodeMgr.eachNode(m.compileNode(farg)) do + if isValidCallArgNode(arg, fn) then + nodeMgr.setNode(arg, fn) end end end @@ -532,6 +554,18 @@ local function setCallArgNode(source, call, callNode, fixIndex) end end +function m.compileCallArg(arg, call, index) + local callNode = m.compileNode(call.node) + compileCallArgNode(arg, call, callNode, 0, index) + + if call.node.special == 'pcall' + or call.node.special == 'xpcall' then + local fixIndex = call.node.special == 'pcall' and 1 or 2 + callNode = m.compileNode(call.args[1]) + compileCallArgNode(arg, call, callNode, fixIndex, index) + end +end + local compilerSwitch = util.switch() : case 'nil' : case 'boolean' @@ -551,18 +585,11 @@ local compilerSwitch = util.switch() if source.parent.type == 'callargs' then local call = source.parent.parent - local callNode = m.compileNode(call.node) - setCallArgNode(source, call, callNode, 0) - - if call.node.special == 'pcall' - or call.node.special == 'xpcall' then - local fixIndex = call.node.special == 'pcall' and 1 or 2 - callNode = m.compileNode(call.args[1]) - setCallArgNode(source, call, callNode, fixIndex) - end + m.compileCallArg(source, call) end if source.parent.type == 'setglobal' + or source.parent.type == 'local' or source.parent.type == 'setlocal' or source.parent.type == 'tablefield' or source.parent.type == 'tableindex' @@ -586,15 +613,7 @@ local compilerSwitch = util.switch() -- table.sort(string[], function (<?x?>) end) if source.parent.type == 'callargs' then local call = source.parent.parent - local callNode = m.compileNode(call.node) - setCallArgNode(source, call, callNode, 0) - - if call.node.special == 'pcall' - or call.node.special == 'xpcall' then - local fixIndex = call.node.special == 'pcall' and 1 or 2 - callNode = m.compileNode(call.args[1]) - setCallArgNode(source, call, callNode, fixIndex) - end + m.compileCallArg(source, call) end end) : case 'paren' @@ -612,7 +631,11 @@ local compilerSwitch = util.switch() if source.ref and not hasMarkDoc then for _, ref in ipairs(source.ref) do if ref.type == 'setlocal' then - nodeMgr.setNode(source, m.compileNode(ref.value)) + if ref.value.type == 'table' then + nodeMgr.setNode(source, ref.value) + else + nodeMgr.setNode(source, m.compileNode(ref.value)) + end end end end @@ -623,7 +646,11 @@ local compilerSwitch = util.switch() end if source.value then if not hasMarkDoc or guide.isLiteral(source.value) then - nodeMgr.setNode(source, m.compileNode(source.value)) + if source.value.type == 'table' then + nodeMgr.setNode(source, source.value) + else + nodeMgr.setNode(source, m.compileNode(source.value)) + end end end -- function x.y(self, ...) --> function x:y(...) @@ -675,6 +702,11 @@ local compilerSwitch = util.switch() : case 'setindex' : call(function (source) compileByLocalID(source) + m.compileByParentNode(source.node, guide.getKeyName(source), function (src) + if src.type == 'doc.type.field' then + nodeMgr.setNode(source, m.compileNode(src)) + end + end) end) : case 'getfield' : case 'getmethod' @@ -695,7 +727,11 @@ local compilerSwitch = util.switch() if source.value then if not hasMarkDoc or guide.isLiteral(source.value) then - nodeMgr.setNode(source, m.compileNode(source.value)) + if source.value.type == 'table' then + nodeMgr.setNode(source, source.value) + else + nodeMgr.setNode(source, m.compileNode(source.value)) + end end end diff --git a/script/vm/infer.lua b/script/vm/infer.lua index c0daf00a..b436f9e1 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -229,6 +229,13 @@ function mt:hasClass() return self._hasClass == true end +---@return boolean +function mt:hasFunction() + self:_computeViews() + return self.views['function'] == true + or self._hasDocFunction == true +end + function mt:_computeViews() if self.views then return diff --git a/test/completion/common.lua b/test/completion/common.lua index 6fde16fe..4058e844 100644 --- a/test/completion/common.lua +++ b/test/completion/common.lua @@ -1526,7 +1526,7 @@ mt.<??> } TEST [[ ----@param x string | "'AAA'" | "'BBB'" | "'CCC'" +---@param x string | "AAA" | "BBB" | "CCC" function f(y, x) end @@ -1534,21 +1534,21 @@ f(1, <??>) ]] { { - label = "'AAA'", + label = '"AAA"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'BBB'", + label = '"BBB"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'CCC'", + label = '"CCC"', kind = define.CompletionItemKind.EnumMember, } } TEST [[ ----@param x string | "'AAA'" | "'BBB'" | "'CCC'" +---@param x string | "AAA" | "BBB" | "CCC" function f(y, x) end @@ -1556,21 +1556,21 @@ f(1,<??>) ]] { { - label = "'AAA'", + label = '"AAA"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'BBB'", + label = '"BBB"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'CCC'", + label = '"CCC"', kind = define.CompletionItemKind.EnumMember, } } TEST [[ ----@param x string | "'AAA'" | "'BBB'" | "'CCC'" +---@param x string | "AAA" | "BBB" | "CCC" function f(x) end @@ -1578,21 +1578,21 @@ f(<??>) ]] { { - label = "'AAA'", + label = '"AAA"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'BBB'", + label = '"BBB"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'CCC'", + label = '"CCC"', kind = define.CompletionItemKind.EnumMember, } } TEST [[ ----@alias Option string | "'AAA'" | "'BBB'" | "'CCC'" +---@alias Option string | "AAA" | "BBB" | "CCC" ---@param x Option function f(x) end @@ -1601,21 +1601,21 @@ f(<??>) ]] { { - label = "'AAA'", + label = '"AAA"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'BBB'", + label = '"BBB"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'CCC'", + label = '"CCC"', kind = define.CompletionItemKind.EnumMember, } } TEST [[ ----@param x string | "'AAA'" | "'BBB'" | "'CCC'" +---@param x string | "AAA" | "BBB" | "CCC" function f(x) end @@ -1646,10 +1646,10 @@ TEST [[ ---@alias XXXX ---comment 1 ---comment 1 ----| '1' +---| 1 ---comment 2 ---comment 2 ----| '2' +---| 2 ---@param x XXXX local function f(x) @@ -1675,10 +1675,10 @@ TEST [[ ---@alias XXXX ---comment 1 ---comment 1 ----| '1' +---| 1 ---comment 2 ---comment 2 ----| '2' +---| 2 ---@param x XXXX local function f(x) end @@ -1686,7 +1686,7 @@ end ---comment 3 ---comment 3 ----| '3' +---| 3 f(<??>) ]] @@ -1760,20 +1760,20 @@ global zzz: integer = 1 TEST [[ ---@param x string ----| "'选项1'" # 注释1 ----| "'选项2'" # 注释2 +---| "选项1" # 注释1 +---| "选项2" # 注释2 function f(x) end f(<??>) ]] { { - label = "'选项1'", + label = '"选项1"', kind = define.CompletionItemKind.EnumMember, description = '注释1', }, { - label = "'选项2'", + label = '"选项2"', kind = define.CompletionItemKind.EnumMember, description = '注释2', }, @@ -1792,49 +1792,49 @@ utf8.charpatter<??> } TEST [[ ----@type "'a'"|"'b'"|"'c'" +---@type "a"|"b"|"c" local x print(x == <??>) ]] { { - label = "'a'", + label = '"a"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'b'", + label = '"b"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'c'", + label = '"c"', kind = define.CompletionItemKind.EnumMember, }, } TEST [[ ----@type "'a'"|"'b'"|"'c'" +---@type "a"|"b"|"c" local x x = <??> ]] { { - label = "'a'", + label = '"a"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'b'", + label = '"b"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'c'", + label = '"c"', kind = define.CompletionItemKind.EnumMember, }, } TEST [[ ----@type "'a'"|"'b'"|"'c'" +---@type "a"|"b"|"c" local x print(x == '<??>') @@ -1858,7 +1858,7 @@ print(x == '<??>') } TEST [[ ----@type "'a'"|"'b'"|"'c'" +---@type "a"|"b"|"c" local x x = '<??>' @@ -2095,85 +2095,85 @@ field cc.aaa: number Cared['description'] = nil TEST [[ ----@type table<string, "'a'"|"'b'"|"'c'"> +---@type table<string, "a"|"b"|"c"> local x x.a = <??> ]] { { - label = "'a'", + label = '"a"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'b'", + label = '"b"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'c'", + label = '"c"', kind = define.CompletionItemKind.EnumMember, }, } TEST [[ ----@type table<string, "'a'"|"'b'"|"'c'"> +---@type table<string, "a"|"b"|"c"> local x x['a'] = <??> ]] { { - label = "'a'", + label = '"a"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'b'", + label = '"b"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'c'", + label = '"c"', kind = define.CompletionItemKind.EnumMember, }, } TEST [[ ----@type table<string, "'a'"|"'b'"|"'c'"> +---@type table<string, "a"|"b"|"c"> local x = { a = <??> } ]] { { - label = "'a'", + label = '"a"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'b'", + label = '"b"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'c'", + label = '"c"', kind = define.CompletionItemKind.EnumMember, }, } TEST [[ ----@type table<string, "'a'"|"'b'"|"'c'"> +---@type table<string, "a"|"b"|"c"> local x = { ['a'] = <??> } ]] { { - label = "'a'", + label = '"a"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'b'", + label = '"b"', kind = define.CompletionItemKind.EnumMember, }, { - label = "'c'", + label = '"c"', kind = define.CompletionItemKind.EnumMember, }, } diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 02e46b01..a6311e6e 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -1367,3 +1367,17 @@ TEST 'function' [[ ---@overload fun() function <?f?>() end ]] + +TEST 'integer' [[ +---@type table<string, integer> +local t + +t.<?a?> +]] + +TEST '"a"|"b"|"c"' [[ +---@type table<string, "a"|"b"|"c"> +local t + +t.<?a?> +]] |