diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2024-02-26 10:58:00 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-26 10:58:00 +0800 |
commit | 73be83cacbe4a759a100a794e871c896a4692399 (patch) | |
tree | e761a59942ab07e6eab12fd7c260122b51a2ccf8 /script | |
parent | f388b95de3188e97d670bd6f1924325637446cf7 (diff) | |
parent | 87c83c38a3e1c4c617daeb104a0e31f5b1deaf1f (diff) | |
download | lua-language-server-73be83cacbe4a759a100a794e871c896a4692399.zip |
Merge pull request #2532 from fesily/automatic-infer-function-param-type
add infer function param type
Diffstat (limited to 'script')
-rw-r--r-- | script/client.lua | 4 | ||||
-rw-r--r-- | script/config/template.lua | 1 | ||||
-rw-r--r-- | script/core/command/autoRequire.lua | 5 | ||||
-rw-r--r-- | script/core/completion/completion.lua | 52 | ||||
-rw-r--r-- | script/core/diagnostics/undefined-doc-name.lua | 13 | ||||
-rw-r--r-- | script/core/highlight.lua | 4 | ||||
-rw-r--r-- | script/fs-utility.lua | 6 | ||||
-rw-r--r-- | script/gc.lua | 3 | ||||
-rw-r--r-- | script/json-edit.lua | 13 | ||||
-rw-r--r-- | script/parser/compile.lua | 27 | ||||
-rw-r--r-- | script/parser/guide.lua | 3 | ||||
-rw-r--r-- | script/parser/luadoc.lua | 2 | ||||
-rw-r--r-- | script/plugin.lua | 1 | ||||
-rw-r--r-- | script/timer.lua | 13 | ||||
-rw-r--r-- | script/vm/compiler.lua | 62 | ||||
-rw-r--r-- | script/vm/visible.lua | 10 | ||||
-rw-r--r-- | script/workspace/scope.lua | 6 |
17 files changed, 142 insertions, 83 deletions
diff --git a/script/client.lua b/script/client.lua index a8eda9b8..e328dc52 100644 --- a/script/client.lua +++ b/script/client.lua @@ -278,7 +278,7 @@ local function searchPatchInfo(cfg, rawKey) } end ----@param uri uri +---@param uri? uri ---@param cfg table ---@param change config.change ---@return json.patch? @@ -330,7 +330,7 @@ local function makeConfigPatch(uri, cfg, change) return nil end ----@param uri uri +---@param uri? uri ---@param path string ---@param changes config.change[] ---@return string? diff --git a/script/config/template.lua b/script/config/template.lua index 2a30d2ea..49907419 100644 --- a/script/config/template.lua +++ b/script/config/template.lua @@ -397,6 +397,7 @@ local template = { ['Lua.type.castNumberToInteger'] = Type.Boolean >> true, ['Lua.type.weakUnionCheck'] = Type.Boolean >> false, ['Lua.type.weakNilCheck'] = Type.Boolean >> false, + ['Lua.type.inferParamType'] = Type.Boolean >> false, ['Lua.doc.privateName'] = Type.Array(Type.String), ['Lua.doc.protectedName'] = Type.Array(Type.String), ['Lua.doc.packageName'] = Type.Array(Type.String), diff --git a/script/core/command/autoRequire.lua b/script/core/command/autoRequire.lua index a96cc918..9f3ff929 100644 --- a/script/core/command/autoRequire.lua +++ b/script/core/command/autoRequire.lua @@ -132,6 +132,7 @@ end ---@async return function (data) + ---@type uri local uri = data.uri local target = data.target local name = data.name @@ -158,5 +159,7 @@ return function (data) end local offset, fmt = findInsertRow(uri) - applyAutoRequire(uri, offset, name, requireName, fmt) + if offset and fmt then + applyAutoRequire(uri, offset, name, requireName, fmt) + end end diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua index acb3adbe..d047dd56 100644 --- a/script/core/completion/completion.lua +++ b/script/core/completion/completion.lua @@ -147,6 +147,9 @@ end local function findParent(state, position) local text = state.lua + if not text then + return + end local offset = guide.positionToOffset(state, position) for i = offset, 1, -1 do local char = text:sub(i, i) @@ -675,6 +678,7 @@ local function checkGlobal(state, word, startPos, position, parent, oop, results end ---@async +---@param parent parser.object local function checkField(state, word, start, position, parent, oop, results) if parent.tag == '_ENV' or parent.special == '_G' then local globals = vm.getGlobalSets(state.uri, 'variable') @@ -955,8 +959,7 @@ local function checkFunctionArgByDocParam(state, word, startPos, results) end end -local function isAfterLocal(state, startPos) - local text = state.lua +local function isAfterLocal(state, text, startPos) local offset = guide.positionToOffset(state, startPos) local pos = lookBackward.skipSpace(text, offset) local word = lookBackward.findWord(text, pos) @@ -965,6 +968,8 @@ end local function collectRequireNames(mode, myUri, literal, source, smark, position, results) local collect = {} + local source_start = source and smark and (source.start + #smark) or position + local source_finish = source and smark and (source.finish - #smark) or position if mode == 'require' then for uri in files.eachFile(myUri) do if myUri == uri then @@ -978,8 +983,8 @@ local function collectRequireNames(mode, myUri, literal, source, smark, position if not collect[info.name] then collect[info.name] = { textEdit = { - start = smark and (source.start + #smark) or position, - finish = smark and (source.finish - #smark) or position, + start = source_start, + finish = source_finish, newText = smark and info.name or util.viewString(info.name), }, path = relative, @@ -1006,8 +1011,8 @@ local function collectRequireNames(mode, myUri, literal, source, smark, position if not collect[open] then collect[open] = { textEdit = { - start = smark and (source.start + #smark) or position, - finish = smark and (source.finish - #smark) or position, + start = source_start, + finish = source_finish, newText = smark and open or util.viewString(open), }, path = path, @@ -1034,8 +1039,8 @@ local function collectRequireNames(mode, myUri, literal, source, smark, position if not collect[path] then collect[path] = { textEdit = { - start = smark and (source.start + #smark) or position, - finish = smark and (source.finish - #smark) or position, + start = source_start, + finish = source_finish, newText = smark and path or util.viewString(path), } } @@ -1097,6 +1102,9 @@ end local function checkLenPlusOne(state, position, results) local text = state.lua + if not text then + return + end guide.eachSourceContain(state.ast, position, function (source) if source.type == 'getindex' or source.type == 'setindex' then @@ -1392,6 +1400,9 @@ end local function checkEqualEnum(state, position, results) local text = state.lua + if not text then + return + end local start = lookBackward.findTargetSymbol(text, guide.positionToOffset(state, position), '=') if not start then return @@ -1493,6 +1504,9 @@ local function tryWord(state, position, triggerCharacter, results) return end local text = state.lua + if not text then + return + end local offset = guide.positionToOffset(state, position) local finish = lookBackward.skipSpace(text, offset) local word, start = lookBackward.findWord(text, offset) @@ -1518,7 +1532,7 @@ local function tryWord(state, position, triggerCharacter, results) checkProvideLocal(state, word, startPos, results) checkFunctionArgByDocParam(state, word, startPos, results) else - local afterLocal = isAfterLocal(state, startPos) + local afterLocal = isAfterLocal(state, text, startPos) local stop = checkKeyWord(state, startPos, position, word, hasSpace, afterLocal, results) if stop then return @@ -1530,8 +1544,10 @@ local function tryWord(state, position, triggerCharacter, results) checkLocal(state, word, startPos, results) checkTableField(state, word, startPos, results) local env = guide.getENV(state.ast, startPos) - checkGlobal(state, word, startPos, position, env, false, results) - checkModule(state, word, startPos, results) + if env then + checkGlobal(state, word, startPos, position, env, false, results) + checkModule(state, word, startPos, results) + end end end end @@ -1592,6 +1608,9 @@ end local function checkTableLiteralField(state, position, tbl, fields, results) local text = state.lua + if not text then + return + end local mark = {} for _, field in ipairs(tbl) do if field.type == 'tablefield' @@ -1610,9 +1629,11 @@ local function checkTableLiteralField(state, position, tbl, fields, results) local left = lookBackward.findWord(text, guide.positionToOffset(state, position)) if not left then local pos = lookBackward.findAnyOffset(text, guide.positionToOffset(state, position)) - local char = text:sub(pos, pos) - if char == '{' or char == ',' or char == ';' then - left = '' + if pos then + local char = text:sub(pos, pos) + if char == '{' or char == ',' or char == ';' then + left = '' + end end end if left then @@ -1801,6 +1822,7 @@ local function getluaDocByContain(state, position) return result end +---@return parser.state.err?, parser.object? local function getluaDocByErr(state, start, position) local targetError for _, err in ipairs(state.errs) do @@ -2008,7 +2030,7 @@ local function tryluaDocByErr(state, position, err, docState, results) for _, doc in ipairs(vm.getDocSets(state.uri)) do if doc.type == 'doc.class' and not used[doc.class[1]] - and doc.class[1] ~= docState.class[1] then + and docState and doc.class[1] ~= docState.class[1] then used[doc.class[1]] = true results[#results+1] = { label = doc.class[1], diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua index 3c8ed469..1c55f3bf 100644 --- a/script/core/diagnostics/undefined-doc-name.lua +++ b/script/core/diagnostics/undefined-doc-name.lua @@ -13,16 +13,6 @@ return function (uri, callback) return end - local function hasNameOfGeneric(name, source) - if not source.typeGeneric then - return false - end - if not source.typeGeneric[name] then - return false - end - return true - end - guide.eachSource(state.ast.docs, function (source) if source.type ~= 'doc.extends.name' and source.type ~= 'doc.type.name' then @@ -35,8 +25,7 @@ return function (uri, callback) if name == '...' or name == '_' or name == 'self' then return end - if #vm.getDocSets(uri, name) > 0 - or hasNameOfGeneric(name, source) then + if #vm.getDocSets(uri, name) > 0 then return end callback { diff --git a/script/core/highlight.lua b/script/core/highlight.lua index 80088680..72214672 100644 --- a/script/core/highlight.lua +++ b/script/core/highlight.lua @@ -63,7 +63,7 @@ local function checkInIf(state, source, text, position) local endA = endB - #'end' + 1 if position >= source.finish - #'end' and position <= source.finish - and text:sub(endA, endB) == 'end' then + and text and text:sub(endA, endB) == 'end' then return true end -- 检查每个子模块 @@ -83,7 +83,7 @@ local function makeIf(state, source, text, callback) -- end local endB = guide.positionToOffset(state, source.finish) local endA = endB - #'end' + 1 - if text:sub(endA, endB) == 'end' then + if text and text:sub(endA, endB) == 'end' then callback(source.finish - #'end', source.finish) end -- 每个子模块 diff --git a/script/fs-utility.lua b/script/fs-utility.lua index 9a45b1cc..8d2bf319 100644 --- a/script/fs-utility.lua +++ b/script/fs-utility.lua @@ -128,6 +128,7 @@ function dfs:__div(filename) return new end +---@package function dfs:_open(index) local paths = split(self.path, '[/\\]') local current = self.files @@ -147,6 +148,7 @@ function dfs:_open(index) return current end +---@package function dfs:_filename() return self.path:match '[^/\\]+$' end @@ -291,6 +293,7 @@ local function fsIsDirectory(path, option) if path.type == 'dummy' then return path:isDirectory() end + ---@cast path -dummyfs local status = fs.symlink_status(path):type() return status == 'directory' end @@ -347,6 +350,7 @@ local function fsSave(path, text, option) return false end if path.type == 'dummy' then + ---@cast path -fs.path local dir = path:_open(-2) if not dir then option.err[#option.err+1] = '无法打开:' .. path:string() @@ -385,6 +389,7 @@ local function fsLoad(path, option) return nil end else + ---@cast path -dummyfs local text, err = m.loadFile(path) if text then return text @@ -407,6 +412,7 @@ local function fsCopy(source, target, option) end return fsSave(target, sourceText, option) else + ---@cast source -dummyfs if target.type == 'dummy' then local sourceText, err = m.loadFile(source) if not sourceText then diff --git a/script/gc.lua b/script/gc.lua index ff22195e..92739585 100644 --- a/script/gc.lua +++ b/script/gc.lua @@ -1,12 +1,13 @@ local util = require 'utility' ---@class gc ----@field _list table +---@field package _list table local mt = {} mt.__index = mt mt.type = 'gc' mt._removed = false +---@package mt._max = 10 local function destroyGCObject(obj) diff --git a/script/json-edit.lua b/script/json-edit.lua index 30a55250..efa1216f 100644 --- a/script/json-edit.lua +++ b/script/json-edit.lua @@ -384,6 +384,7 @@ end local JsonEmpty = function () end +---@return {s: integer, d:integer, f:integer, v: any} local function decode_ast(str) if type(str) ~= "string" then error("expected argument of type string, got " .. type(str)) @@ -607,7 +608,11 @@ function OP.add(str, option, path, value) end local ast = decode_ast(str) if ast.v == JsonEmpty then - local pathlst = split_path(path) + local pathlst, err = split_path(path) + if not pathlst then + error(err) + return + end value = add_prefix(value, pathlst) return json.beautify(value, option) end @@ -674,7 +679,11 @@ function OP.replace(str, option, path, value) end local ast = decode_ast(str) if ast.v == JsonEmpty then - local pathlst = split_path(path) + local pathlst, err = split_path(path) + if not pathlst then + error(err) + return + end value = add_prefix(value, pathlst) return json.beautify(value, option) end diff --git a/script/parser/compile.lua b/script/parser/compile.lua index 5321d9b8..8dd772db 100644 --- a/script/parser/compile.lua +++ b/script/parser/compile.lua @@ -239,6 +239,16 @@ local LocalLimit = 200 local parseExp, parseAction +---@class parser.state.err +---@field type string +---@field start? parser.position +---@field finish? parser.position +---@field info? table +---@field fix? table +---@field version? string[]|string +---@field level? string | 'Error' | 'Warning' + +---@type fun(err:parser.state.err):parser.state.err|nil local pushError local function addSpecial(name, obj) @@ -711,6 +721,7 @@ local function parseLocalAttrs() return attrs end +---@param obj table local function createLocal(obj, attrs) obj.type = 'local' obj.effect = obj.finish @@ -1709,6 +1720,9 @@ local function parseTable() end local function addDummySelf(node, call) + if not node then + return + end if node.type ~= 'getmethod' then return end @@ -1736,6 +1750,9 @@ local function checkAmbiguityCall(call, parenPos) return end local node = call.node + if not node then + return + end local nodeRow = guide.rowColOf(node.finish) local callRow = guide.rowColOf(parenPos) if nodeRow == callRow then @@ -2470,7 +2487,10 @@ local function parseExpUnit() local node = parseName() if node then - return parseSimple(resolveName(node), false) + local nameNode = resolveName(node) + if nameNode then + return parseSimple(nameNode, false) + end end return nil @@ -3421,6 +3441,7 @@ local function parseFor() forStateVars = 3 LocalCount = LocalCount + forStateVars if name then + ---@cast name parser.object local loc = createLocal(name) loc.parent = action action.finish = name.finish @@ -3523,7 +3544,9 @@ local function parseFor() list.range = lastName and lastName.range or inRight action.keys = list for i = 1, #list do - local loc = createLocal(list[i]) + local obj = list[i] + ---@cast obj parser.object + local loc = createLocal(obj) loc.parent = action loc.effect = action.finish end diff --git a/script/parser/guide.lua b/script/parser/guide.lua index fd779da0..ac7a5ce0 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -74,9 +74,12 @@ local type = type ---@field hasBreak? true ---@field hasExit? true ---@field [integer] parser.object|any +---@field dot { type: string, start: integer, finish: integer } +---@field colon { type: string, start: integer, finish: integer } ---@field package _root parser.object ---@field package _eachCache? parser.object[] ---@field package _isGlobal? boolean +---@field package _typeCache? parser.object[][] ---@class guide ---@field debugMode boolean diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index aec8994b..25ff71c1 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -1620,7 +1620,7 @@ local function trimTailComment(text) and comment:find '[\'"%]]%s*$' then local state = compile(comment:gsub('^%s+', ''), 'String') if state and state.ast then - comment = state.ast[1] + comment = state.ast[1] --[[@as string]] end end return util.trim(comment) diff --git a/script/plugin.lua b/script/plugin.lua index b77511ff..35a1da5b 100644 --- a/script/plugin.lua +++ b/script/plugin.lua @@ -62,6 +62,7 @@ end function m.getVmPlugin(uri) local scp = scope.getScope(uri) + ---@type pluginInterfaces local interfaces = scp:get('pluginInterfaces') if not interfaces then return diff --git a/script/timer.lua b/script/timer.lua index 09bbce0f..14d33f6a 100644 --- a/script/timer.lua +++ b/script/timer.lua @@ -14,6 +14,7 @@ local curIndex = 0 local tarFrame = 0 local fwFrame = 0 local freeQueue = {} +---@type (timer|false)[][] local timer = {} local function allocQueue() @@ -101,9 +102,10 @@ end local m = {} ---@class timer ----@field _onTimer? fun(self: timer) ----@field _timeoutFrame integer ----@field _timeout integer +---@field package _onTimer? fun(self: timer) +---@field package _timeoutFrame integer +---@field package _timeout integer +---@field package _timerCount integer local mt = {} mt.__index = mt mt.type = 'timer' @@ -119,6 +121,7 @@ function mt:__call() end function mt:remove() + ---@package self._removed = true end @@ -126,7 +129,9 @@ function mt:pause() if self._removed or self._pauseRemaining then return end + ---@package self._pauseRemaining = getRemaining(self) + ---@package self._running = false local ti = self._timeoutFrame local q = timer[ti] @@ -145,6 +150,7 @@ function mt:resume() return end local timeout = self._pauseRemaining + ---@package self._pauseRemaining = nil mTimeout(self, timeout) end @@ -163,6 +169,7 @@ function mt:restart() end end end + ---@package self._running = false mTimeout(self, self._timeout) end diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 2253c83a..fc8f7c52 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1031,6 +1031,7 @@ local function compileForVars(source, target) return false end +---@param func parser.object ---@param source parser.object local function compileFunctionParam(func, source) -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number @@ -1050,33 +1051,31 @@ local function compileFunctionParam(func, source) end end end - if func.parent.type == 'local' then + + local derviationParam = config.get(guide.getUri(func), 'Lua.type.inferParamType') + if derviationParam and func.parent.type == 'local' and func.parent.ref then local refs = func.parent.ref - local findCall - if refs then - for i, ref in ipairs(refs) do - if ref.parent.type == 'call' then - findCall = ref.parent - break - end + local found + for _, ref in ipairs(refs) do + if ref.parent.type ~= 'call' then + goto continue end - end - if findCall and findCall.args then - local index - for i, arg in ipairs(source.parent) do - if arg == source then - index = i - break - end + local caller = ref.parent + if not caller.args then + goto continue end - if index then - local callerArg = findCall.args[index] - if callerArg then - vm.setNode(source, vm.compileNode(callerArg)) - return true + for index, arg in ipairs(source.parent) do + if arg == source then + local callerArg = caller.args[index] + if callerArg then + vm.setNode(source, vm.compileNode(callerArg)) + finded = true + end end end + ::continue:: end + return finded end end @@ -1121,24 +1120,9 @@ local function compileLocal(source) end if source.parent.type == 'funcargs' and not hasMarkDoc and not hasMarkParam then local func = source.parent.parent - -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number - local funcNode = vm.compileNode(func) - local hasDocArg - for n in funcNode:eachObject() do - if n.type == 'doc.type.function' then - for index, arg in ipairs(n.args) do - if func.args[index] == source then - local argNode = vm.compileNode(arg) - for an in argNode:eachObject() do - if an.type ~= 'doc.generic.name' then - vm.setNode(source, an) - end - end - hasDocArg = true - end - end - end - end + local vmPlugin = plugin.getVmPlugin(guide.getUri(source)) + local hasDocArg = vmPlugin and vmPlugin.OnCompileFunctionParam(compileFunctionParam, func, source) + or compileFunctionParam(func, source) if not hasDocArg then vm.setNode(source, vm.declareGlobal('type', 'any')) end diff --git a/script/vm/visible.lua b/script/vm/visible.lua index d13ecf1f..0f486d6b 100644 --- a/script/vm/visible.lua +++ b/script/vm/visible.lua @@ -31,6 +31,10 @@ local function getVisibleType(source) source._visibleType = 'protected' return 'protected' end + if doc.type == 'doc.package' then + source._visibleType = 'package' + return 'package' + end end end @@ -50,6 +54,12 @@ local function getVisibleType(source) source._visibleType = 'protected' return 'protected' end + + local packageNames = config.get(uri, 'Lua.doc.packageName') + if #packageNames > 0 and glob.glob(packageNames)(fieldName) then + source._visibleType = 'package' + return 'package' + end end source._visibleType = 'public' diff --git a/script/workspace/scope.lua b/script/workspace/scope.lua index 789b5f81..cfdfdc90 100644 --- a/script/workspace/scope.lua +++ b/script/workspace/scope.lua @@ -235,11 +235,11 @@ function m.getLinkedScope(uri) return nil end ----@param uri uri +---@param uri? uri ---@return scope function m.getScope(uri) - return m.getFolder(uri) - or m.getLinkedScope(uri) + return uri and (m.getFolder(uri) + or m.getLinkedScope(uri)) or m.fallback end |