summaryrefslogtreecommitdiff
path: root/script
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2024-02-26 10:58:00 +0800
committerGitHub <noreply@github.com>2024-02-26 10:58:00 +0800
commit73be83cacbe4a759a100a794e871c896a4692399 (patch)
treee761a59942ab07e6eab12fd7c260122b51a2ccf8 /script
parentf388b95de3188e97d670bd6f1924325637446cf7 (diff)
parent87c83c38a3e1c4c617daeb104a0e31f5b1deaf1f (diff)
downloadlua-language-server-73be83cacbe4a759a100a794e871c896a4692399.zip
Merge pull request #2532 from fesily/automatic-infer-function-param-type
add infer function param type
Diffstat (limited to 'script')
-rw-r--r--script/client.lua4
-rw-r--r--script/config/template.lua1
-rw-r--r--script/core/command/autoRequire.lua5
-rw-r--r--script/core/completion/completion.lua52
-rw-r--r--script/core/diagnostics/undefined-doc-name.lua13
-rw-r--r--script/core/highlight.lua4
-rw-r--r--script/fs-utility.lua6
-rw-r--r--script/gc.lua3
-rw-r--r--script/json-edit.lua13
-rw-r--r--script/parser/compile.lua27
-rw-r--r--script/parser/guide.lua3
-rw-r--r--script/parser/luadoc.lua2
-rw-r--r--script/plugin.lua1
-rw-r--r--script/timer.lua13
-rw-r--r--script/vm/compiler.lua62
-rw-r--r--script/vm/visible.lua10
-rw-r--r--script/workspace/scope.lua6
17 files changed, 142 insertions, 83 deletions
diff --git a/script/client.lua b/script/client.lua
index a8eda9b8..e328dc52 100644
--- a/script/client.lua
+++ b/script/client.lua
@@ -278,7 +278,7 @@ local function searchPatchInfo(cfg, rawKey)
}
end
----@param uri uri
+---@param uri? uri
---@param cfg table
---@param change config.change
---@return json.patch?
@@ -330,7 +330,7 @@ local function makeConfigPatch(uri, cfg, change)
return nil
end
----@param uri uri
+---@param uri? uri
---@param path string
---@param changes config.change[]
---@return string?
diff --git a/script/config/template.lua b/script/config/template.lua
index 2a30d2ea..49907419 100644
--- a/script/config/template.lua
+++ b/script/config/template.lua
@@ -397,6 +397,7 @@ local template = {
['Lua.type.castNumberToInteger'] = Type.Boolean >> true,
['Lua.type.weakUnionCheck'] = Type.Boolean >> false,
['Lua.type.weakNilCheck'] = Type.Boolean >> false,
+ ['Lua.type.inferParamType'] = Type.Boolean >> false,
['Lua.doc.privateName'] = Type.Array(Type.String),
['Lua.doc.protectedName'] = Type.Array(Type.String),
['Lua.doc.packageName'] = Type.Array(Type.String),
diff --git a/script/core/command/autoRequire.lua b/script/core/command/autoRequire.lua
index a96cc918..9f3ff929 100644
--- a/script/core/command/autoRequire.lua
+++ b/script/core/command/autoRequire.lua
@@ -132,6 +132,7 @@ end
---@async
return function (data)
+ ---@type uri
local uri = data.uri
local target = data.target
local name = data.name
@@ -158,5 +159,7 @@ return function (data)
end
local offset, fmt = findInsertRow(uri)
- applyAutoRequire(uri, offset, name, requireName, fmt)
+ if offset and fmt then
+ applyAutoRequire(uri, offset, name, requireName, fmt)
+ end
end
diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua
index acb3adbe..d047dd56 100644
--- a/script/core/completion/completion.lua
+++ b/script/core/completion/completion.lua
@@ -147,6 +147,9 @@ end
local function findParent(state, position)
local text = state.lua
+ if not text then
+ return
+ end
local offset = guide.positionToOffset(state, position)
for i = offset, 1, -1 do
local char = text:sub(i, i)
@@ -675,6 +678,7 @@ local function checkGlobal(state, word, startPos, position, parent, oop, results
end
---@async
+---@param parent parser.object
local function checkField(state, word, start, position, parent, oop, results)
if parent.tag == '_ENV' or parent.special == '_G' then
local globals = vm.getGlobalSets(state.uri, 'variable')
@@ -955,8 +959,7 @@ local function checkFunctionArgByDocParam(state, word, startPos, results)
end
end
-local function isAfterLocal(state, startPos)
- local text = state.lua
+local function isAfterLocal(state, text, startPos)
local offset = guide.positionToOffset(state, startPos)
local pos = lookBackward.skipSpace(text, offset)
local word = lookBackward.findWord(text, pos)
@@ -965,6 +968,8 @@ end
local function collectRequireNames(mode, myUri, literal, source, smark, position, results)
local collect = {}
+ local source_start = source and smark and (source.start + #smark) or position
+ local source_finish = source and smark and (source.finish - #smark) or position
if mode == 'require' then
for uri in files.eachFile(myUri) do
if myUri == uri then
@@ -978,8 +983,8 @@ local function collectRequireNames(mode, myUri, literal, source, smark, position
if not collect[info.name] then
collect[info.name] = {
textEdit = {
- start = smark and (source.start + #smark) or position,
- finish = smark and (source.finish - #smark) or position,
+ start = source_start,
+ finish = source_finish,
newText = smark and info.name or util.viewString(info.name),
},
path = relative,
@@ -1006,8 +1011,8 @@ local function collectRequireNames(mode, myUri, literal, source, smark, position
if not collect[open] then
collect[open] = {
textEdit = {
- start = smark and (source.start + #smark) or position,
- finish = smark and (source.finish - #smark) or position,
+ start = source_start,
+ finish = source_finish,
newText = smark and open or util.viewString(open),
},
path = path,
@@ -1034,8 +1039,8 @@ local function collectRequireNames(mode, myUri, literal, source, smark, position
if not collect[path] then
collect[path] = {
textEdit = {
- start = smark and (source.start + #smark) or position,
- finish = smark and (source.finish - #smark) or position,
+ start = source_start,
+ finish = source_finish,
newText = smark and path or util.viewString(path),
}
}
@@ -1097,6 +1102,9 @@ end
local function checkLenPlusOne(state, position, results)
local text = state.lua
+ if not text then
+ return
+ end
guide.eachSourceContain(state.ast, position, function (source)
if source.type == 'getindex'
or source.type == 'setindex' then
@@ -1392,6 +1400,9 @@ end
local function checkEqualEnum(state, position, results)
local text = state.lua
+ if not text then
+ return
+ end
local start = lookBackward.findTargetSymbol(text, guide.positionToOffset(state, position), '=')
if not start then
return
@@ -1493,6 +1504,9 @@ local function tryWord(state, position, triggerCharacter, results)
return
end
local text = state.lua
+ if not text then
+ return
+ end
local offset = guide.positionToOffset(state, position)
local finish = lookBackward.skipSpace(text, offset)
local word, start = lookBackward.findWord(text, offset)
@@ -1518,7 +1532,7 @@ local function tryWord(state, position, triggerCharacter, results)
checkProvideLocal(state, word, startPos, results)
checkFunctionArgByDocParam(state, word, startPos, results)
else
- local afterLocal = isAfterLocal(state, startPos)
+ local afterLocal = isAfterLocal(state, text, startPos)
local stop = checkKeyWord(state, startPos, position, word, hasSpace, afterLocal, results)
if stop then
return
@@ -1530,8 +1544,10 @@ local function tryWord(state, position, triggerCharacter, results)
checkLocal(state, word, startPos, results)
checkTableField(state, word, startPos, results)
local env = guide.getENV(state.ast, startPos)
- checkGlobal(state, word, startPos, position, env, false, results)
- checkModule(state, word, startPos, results)
+ if env then
+ checkGlobal(state, word, startPos, position, env, false, results)
+ checkModule(state, word, startPos, results)
+ end
end
end
end
@@ -1592,6 +1608,9 @@ end
local function checkTableLiteralField(state, position, tbl, fields, results)
local text = state.lua
+ if not text then
+ return
+ end
local mark = {}
for _, field in ipairs(tbl) do
if field.type == 'tablefield'
@@ -1610,9 +1629,11 @@ local function checkTableLiteralField(state, position, tbl, fields, results)
local left = lookBackward.findWord(text, guide.positionToOffset(state, position))
if not left then
local pos = lookBackward.findAnyOffset(text, guide.positionToOffset(state, position))
- local char = text:sub(pos, pos)
- if char == '{' or char == ',' or char == ';' then
- left = ''
+ if pos then
+ local char = text:sub(pos, pos)
+ if char == '{' or char == ',' or char == ';' then
+ left = ''
+ end
end
end
if left then
@@ -1801,6 +1822,7 @@ local function getluaDocByContain(state, position)
return result
end
+---@return parser.state.err?, parser.object?
local function getluaDocByErr(state, start, position)
local targetError
for _, err in ipairs(state.errs) do
@@ -2008,7 +2030,7 @@ local function tryluaDocByErr(state, position, err, docState, results)
for _, doc in ipairs(vm.getDocSets(state.uri)) do
if doc.type == 'doc.class'
and not used[doc.class[1]]
- and doc.class[1] ~= docState.class[1] then
+ and docState and doc.class[1] ~= docState.class[1] then
used[doc.class[1]] = true
results[#results+1] = {
label = doc.class[1],
diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua
index 3c8ed469..1c55f3bf 100644
--- a/script/core/diagnostics/undefined-doc-name.lua
+++ b/script/core/diagnostics/undefined-doc-name.lua
@@ -13,16 +13,6 @@ return function (uri, callback)
return
end
- local function hasNameOfGeneric(name, source)
- if not source.typeGeneric then
- return false
- end
- if not source.typeGeneric[name] then
- return false
- end
- return true
- end
-
guide.eachSource(state.ast.docs, function (source)
if source.type ~= 'doc.extends.name'
and source.type ~= 'doc.type.name' then
@@ -35,8 +25,7 @@ return function (uri, callback)
if name == '...' or name == '_' or name == 'self' then
return
end
- if #vm.getDocSets(uri, name) > 0
- or hasNameOfGeneric(name, source) then
+ if #vm.getDocSets(uri, name) > 0 then
return
end
callback {
diff --git a/script/core/highlight.lua b/script/core/highlight.lua
index 80088680..72214672 100644
--- a/script/core/highlight.lua
+++ b/script/core/highlight.lua
@@ -63,7 +63,7 @@ local function checkInIf(state, source, text, position)
local endA = endB - #'end' + 1
if position >= source.finish - #'end'
and position <= source.finish
- and text:sub(endA, endB) == 'end' then
+ and text and text:sub(endA, endB) == 'end' then
return true
end
-- 检查每个子模块
@@ -83,7 +83,7 @@ local function makeIf(state, source, text, callback)
-- end
local endB = guide.positionToOffset(state, source.finish)
local endA = endB - #'end' + 1
- if text:sub(endA, endB) == 'end' then
+ if text and text:sub(endA, endB) == 'end' then
callback(source.finish - #'end', source.finish)
end
-- 每个子模块
diff --git a/script/fs-utility.lua b/script/fs-utility.lua
index 9a45b1cc..8d2bf319 100644
--- a/script/fs-utility.lua
+++ b/script/fs-utility.lua
@@ -128,6 +128,7 @@ function dfs:__div(filename)
return new
end
+---@package
function dfs:_open(index)
local paths = split(self.path, '[/\\]')
local current = self.files
@@ -147,6 +148,7 @@ function dfs:_open(index)
return current
end
+---@package
function dfs:_filename()
return self.path:match '[^/\\]+$'
end
@@ -291,6 +293,7 @@ local function fsIsDirectory(path, option)
if path.type == 'dummy' then
return path:isDirectory()
end
+ ---@cast path -dummyfs
local status = fs.symlink_status(path):type()
return status == 'directory'
end
@@ -347,6 +350,7 @@ local function fsSave(path, text, option)
return false
end
if path.type == 'dummy' then
+ ---@cast path -fs.path
local dir = path:_open(-2)
if not dir then
option.err[#option.err+1] = '无法打开:' .. path:string()
@@ -385,6 +389,7 @@ local function fsLoad(path, option)
return nil
end
else
+ ---@cast path -dummyfs
local text, err = m.loadFile(path)
if text then
return text
@@ -407,6 +412,7 @@ local function fsCopy(source, target, option)
end
return fsSave(target, sourceText, option)
else
+ ---@cast source -dummyfs
if target.type == 'dummy' then
local sourceText, err = m.loadFile(source)
if not sourceText then
diff --git a/script/gc.lua b/script/gc.lua
index ff22195e..92739585 100644
--- a/script/gc.lua
+++ b/script/gc.lua
@@ -1,12 +1,13 @@
local util = require 'utility'
---@class gc
----@field _list table
+---@field package _list table
local mt = {}
mt.__index = mt
mt.type = 'gc'
mt._removed = false
+---@package
mt._max = 10
local function destroyGCObject(obj)
diff --git a/script/json-edit.lua b/script/json-edit.lua
index 30a55250..efa1216f 100644
--- a/script/json-edit.lua
+++ b/script/json-edit.lua
@@ -384,6 +384,7 @@ end
local JsonEmpty = function () end
+---@return {s: integer, d:integer, f:integer, v: any}
local function decode_ast(str)
if type(str) ~= "string" then
error("expected argument of type string, got " .. type(str))
@@ -607,7 +608,11 @@ function OP.add(str, option, path, value)
end
local ast = decode_ast(str)
if ast.v == JsonEmpty then
- local pathlst = split_path(path)
+ local pathlst, err = split_path(path)
+ if not pathlst then
+ error(err)
+ return
+ end
value = add_prefix(value, pathlst)
return json.beautify(value, option)
end
@@ -674,7 +679,11 @@ function OP.replace(str, option, path, value)
end
local ast = decode_ast(str)
if ast.v == JsonEmpty then
- local pathlst = split_path(path)
+ local pathlst, err = split_path(path)
+ if not pathlst then
+ error(err)
+ return
+ end
value = add_prefix(value, pathlst)
return json.beautify(value, option)
end
diff --git a/script/parser/compile.lua b/script/parser/compile.lua
index 5321d9b8..8dd772db 100644
--- a/script/parser/compile.lua
+++ b/script/parser/compile.lua
@@ -239,6 +239,16 @@ local LocalLimit = 200
local parseExp, parseAction
+---@class parser.state.err
+---@field type string
+---@field start? parser.position
+---@field finish? parser.position
+---@field info? table
+---@field fix? table
+---@field version? string[]|string
+---@field level? string | 'Error' | 'Warning'
+
+---@type fun(err:parser.state.err):parser.state.err|nil
local pushError
local function addSpecial(name, obj)
@@ -711,6 +721,7 @@ local function parseLocalAttrs()
return attrs
end
+---@param obj table
local function createLocal(obj, attrs)
obj.type = 'local'
obj.effect = obj.finish
@@ -1709,6 +1720,9 @@ local function parseTable()
end
local function addDummySelf(node, call)
+ if not node then
+ return
+ end
if node.type ~= 'getmethod' then
return
end
@@ -1736,6 +1750,9 @@ local function checkAmbiguityCall(call, parenPos)
return
end
local node = call.node
+ if not node then
+ return
+ end
local nodeRow = guide.rowColOf(node.finish)
local callRow = guide.rowColOf(parenPos)
if nodeRow == callRow then
@@ -2470,7 +2487,10 @@ local function parseExpUnit()
local node = parseName()
if node then
- return parseSimple(resolveName(node), false)
+ local nameNode = resolveName(node)
+ if nameNode then
+ return parseSimple(nameNode, false)
+ end
end
return nil
@@ -3421,6 +3441,7 @@ local function parseFor()
forStateVars = 3
LocalCount = LocalCount + forStateVars
if name then
+ ---@cast name parser.object
local loc = createLocal(name)
loc.parent = action
action.finish = name.finish
@@ -3523,7 +3544,9 @@ local function parseFor()
list.range = lastName and lastName.range or inRight
action.keys = list
for i = 1, #list do
- local loc = createLocal(list[i])
+ local obj = list[i]
+ ---@cast obj parser.object
+ local loc = createLocal(obj)
loc.parent = action
loc.effect = action.finish
end
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index fd779da0..ac7a5ce0 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -74,9 +74,12 @@ local type = type
---@field hasBreak? true
---@field hasExit? true
---@field [integer] parser.object|any
+---@field dot { type: string, start: integer, finish: integer }
+---@field colon { type: string, start: integer, finish: integer }
---@field package _root parser.object
---@field package _eachCache? parser.object[]
---@field package _isGlobal? boolean
+---@field package _typeCache? parser.object[][]
---@class guide
---@field debugMode boolean
diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua
index aec8994b..25ff71c1 100644
--- a/script/parser/luadoc.lua
+++ b/script/parser/luadoc.lua
@@ -1620,7 +1620,7 @@ local function trimTailComment(text)
and comment:find '[\'"%]]%s*$' then
local state = compile(comment:gsub('^%s+', ''), 'String')
if state and state.ast then
- comment = state.ast[1]
+ comment = state.ast[1] --[[@as string]]
end
end
return util.trim(comment)
diff --git a/script/plugin.lua b/script/plugin.lua
index b77511ff..35a1da5b 100644
--- a/script/plugin.lua
+++ b/script/plugin.lua
@@ -62,6 +62,7 @@ end
function m.getVmPlugin(uri)
local scp = scope.getScope(uri)
+ ---@type pluginInterfaces
local interfaces = scp:get('pluginInterfaces')
if not interfaces then
return
diff --git a/script/timer.lua b/script/timer.lua
index 09bbce0f..14d33f6a 100644
--- a/script/timer.lua
+++ b/script/timer.lua
@@ -14,6 +14,7 @@ local curIndex = 0
local tarFrame = 0
local fwFrame = 0
local freeQueue = {}
+---@type (timer|false)[][]
local timer = {}
local function allocQueue()
@@ -101,9 +102,10 @@ end
local m = {}
---@class timer
----@field _onTimer? fun(self: timer)
----@field _timeoutFrame integer
----@field _timeout integer
+---@field package _onTimer? fun(self: timer)
+---@field package _timeoutFrame integer
+---@field package _timeout integer
+---@field package _timerCount integer
local mt = {}
mt.__index = mt
mt.type = 'timer'
@@ -119,6 +121,7 @@ function mt:__call()
end
function mt:remove()
+ ---@package
self._removed = true
end
@@ -126,7 +129,9 @@ function mt:pause()
if self._removed or self._pauseRemaining then
return
end
+ ---@package
self._pauseRemaining = getRemaining(self)
+ ---@package
self._running = false
local ti = self._timeoutFrame
local q = timer[ti]
@@ -145,6 +150,7 @@ function mt:resume()
return
end
local timeout = self._pauseRemaining
+ ---@package
self._pauseRemaining = nil
mTimeout(self, timeout)
end
@@ -163,6 +169,7 @@ function mt:restart()
end
end
end
+ ---@package
self._running = false
mTimeout(self, self._timeout)
end
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 2253c83a..fc8f7c52 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -1031,6 +1031,7 @@ local function compileForVars(source, target)
return false
end
+---@param func parser.object
---@param source parser.object
local function compileFunctionParam(func, source)
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
@@ -1050,33 +1051,31 @@ local function compileFunctionParam(func, source)
end
end
end
- if func.parent.type == 'local' then
+
+ local derviationParam = config.get(guide.getUri(func), 'Lua.type.inferParamType')
+ if derviationParam and func.parent.type == 'local' and func.parent.ref then
local refs = func.parent.ref
- local findCall
- if refs then
- for i, ref in ipairs(refs) do
- if ref.parent.type == 'call' then
- findCall = ref.parent
- break
- end
+ local found
+ for _, ref in ipairs(refs) do
+ if ref.parent.type ~= 'call' then
+ goto continue
end
- end
- if findCall and findCall.args then
- local index
- for i, arg in ipairs(source.parent) do
- if arg == source then
- index = i
- break
- end
+ local caller = ref.parent
+ if not caller.args then
+ goto continue
end
- if index then
- local callerArg = findCall.args[index]
- if callerArg then
- vm.setNode(source, vm.compileNode(callerArg))
- return true
+ for index, arg in ipairs(source.parent) do
+ if arg == source then
+ local callerArg = caller.args[index]
+ if callerArg then
+ vm.setNode(source, vm.compileNode(callerArg))
+ finded = true
+ end
end
end
+ ::continue::
end
+ return finded
end
end
@@ -1121,24 +1120,9 @@ local function compileLocal(source)
end
if source.parent.type == 'funcargs' and not hasMarkDoc and not hasMarkParam then
local func = source.parent.parent
- -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
- local funcNode = vm.compileNode(func)
- local hasDocArg
- for n in funcNode:eachObject() do
- if n.type == 'doc.type.function' then
- for index, arg in ipairs(n.args) do
- if func.args[index] == source then
- local argNode = vm.compileNode(arg)
- for an in argNode:eachObject() do
- if an.type ~= 'doc.generic.name' then
- vm.setNode(source, an)
- end
- end
- hasDocArg = true
- end
- end
- end
- end
+ local vmPlugin = plugin.getVmPlugin(guide.getUri(source))
+ local hasDocArg = vmPlugin and vmPlugin.OnCompileFunctionParam(compileFunctionParam, func, source)
+ or compileFunctionParam(func, source)
if not hasDocArg then
vm.setNode(source, vm.declareGlobal('type', 'any'))
end
diff --git a/script/vm/visible.lua b/script/vm/visible.lua
index d13ecf1f..0f486d6b 100644
--- a/script/vm/visible.lua
+++ b/script/vm/visible.lua
@@ -31,6 +31,10 @@ local function getVisibleType(source)
source._visibleType = 'protected'
return 'protected'
end
+ if doc.type == 'doc.package' then
+ source._visibleType = 'package'
+ return 'package'
+ end
end
end
@@ -50,6 +54,12 @@ local function getVisibleType(source)
source._visibleType = 'protected'
return 'protected'
end
+
+ local packageNames = config.get(uri, 'Lua.doc.packageName')
+ if #packageNames > 0 and glob.glob(packageNames)(fieldName) then
+ source._visibleType = 'package'
+ return 'package'
+ end
end
source._visibleType = 'public'
diff --git a/script/workspace/scope.lua b/script/workspace/scope.lua
index 789b5f81..cfdfdc90 100644
--- a/script/workspace/scope.lua
+++ b/script/workspace/scope.lua
@@ -235,11 +235,11 @@ function m.getLinkedScope(uri)
return nil
end
----@param uri uri
+---@param uri? uri
---@return scope
function m.getScope(uri)
- return m.getFolder(uri)
- or m.getLinkedScope(uri)
+ return uri and (m.getFolder(uri)
+ or m.getLinkedScope(uri))
or m.fallback
end