summaryrefslogtreecommitdiff
path: root/script
diff options
context:
space:
mode:
authorfesily <fesil@foxmail.com>2024-01-10 11:03:21 +0800
committerGitHub <noreply@github.com>2024-01-10 11:03:21 +0800
commit1e7bb72ad3ff2b75a1c55ee4bc53004cb7fe30f7 (patch)
treedd94cf09b9e3a675a73a9f1d41248d1538165997 /script
parentbb6e172d6166190bd4edd3bb56230a7d60ebcb93 (diff)
parent37779f9b2493e51e59e1e4366bf7dcb8350e69bd (diff)
downloadlua-language-server-1e7bb72ad3ff2b75a1c55ee4bc53004cb7fe30f7.zip
Merge branch 'LuaLS:master' into plugin-add-OnTransformAst
Diffstat (limited to 'script')
-rw-r--r--script/await.lua5
-rw-r--r--script/brave/work.lua44
-rw-r--r--script/cli/check.lua8
-rw-r--r--script/cli/doc.lua8
-rw-r--r--script/config/template.lua7
-rw-r--r--script/core/code-action.lua54
-rw-r--r--script/core/code-lens.lua3
-rw-r--r--script/core/command/autoRequire.lua11
-rw-r--r--script/core/completion/completion.lua122
-rw-r--r--script/core/completion/keyword.lua35
-rw-r--r--script/core/diagnostics/cast-local-type.lua3
-rw-r--r--script/core/diagnostics/helper/missing-doc-helper.lua8
-rw-r--r--script/core/diagnostics/incomplete-signature-doc.lua28
-rw-r--r--script/core/diagnostics/inject-field.lua147
-rw-r--r--script/core/diagnostics/missing-fields.lua84
-rw-r--r--script/core/diagnostics/missing-local-export-doc.lua6
-rw-r--r--script/core/diagnostics/param-type-mismatch.lua18
-rw-r--r--script/core/diagnostics/undefined-field.lua8
-rw-r--r--script/core/diagnostics/undefined-global.lua46
-rw-r--r--script/core/hover/description.lua87
-rw-r--r--script/core/hover/label.lua10
-rw-r--r--script/core/rename.lua41
-rw-r--r--script/core/semantic-tokens.lua28
-rw-r--r--script/core/signature.lua39
-rw-r--r--script/file-uri.lua5
-rw-r--r--script/files.lua31
-rw-r--r--script/global.d.lua9
-rw-r--r--script/library.lua75
-rw-r--r--script/parser/compile.lua9
-rw-r--r--script/parser/guide.lua13
-rw-r--r--script/parser/luadoc.lua65
-rw-r--r--script/plugins/ffi/c-parser/ctypes.lua5
-rw-r--r--script/plugins/ffi/cdefRerence.lua2
-rw-r--r--script/proto/converter.lua8
-rw-r--r--script/proto/diagnostic.lua2
-rw-r--r--script/proto/proto.lua45
-rw-r--r--script/provider/formatting.lua11
-rw-r--r--script/provider/provider.lua10
-rw-r--r--script/service/net.lua208
-rw-r--r--script/service/service.lua1
-rw-r--r--script/utility.lua137
-rw-r--r--script/vm/compiler.lua139
-rw-r--r--script/vm/doc.lua21
-rw-r--r--script/vm/function.lua10
-rw-r--r--script/vm/global.lua67
-rw-r--r--script/vm/infer.lua19
-rw-r--r--script/vm/node.lua1
-rw-r--r--script/vm/operator.lua8
-rw-r--r--script/vm/sign.lua18
-rw-r--r--script/vm/type.lua30
-rw-r--r--script/vm/visible.lua28
-rw-r--r--script/workspace/workspace.lua12
52 files changed, 1354 insertions, 485 deletions
diff --git a/script/await.lua b/script/await.lua
index fa2aea13..22745570 100644
--- a/script/await.lua
+++ b/script/await.lua
@@ -108,6 +108,11 @@ function m.hasID(id, co)
return m.idMap[id] and m.idMap[id][co] ~= nil
end
+function m.unique(id, callback)
+ m.close(id)
+ m.setID(id, callback)
+end
+
--- 休眠一段时间
---@param time number
---@async
diff --git a/script/brave/work.lua b/script/brave/work.lua
index 9eb756eb..a6a7a41e 100644
--- a/script/brave/work.lua
+++ b/script/brave/work.lua
@@ -15,10 +15,7 @@ end)
brave.on('loadProtoBySocket', function (param)
local jsonrpc = require 'jsonrpc'
- local socket = require 'bee.socket'
- local util = require 'utility'
- local rfd = socket.fd(param.rfd)
- local wfd = socket.fd(param.wfd)
+ local net = require 'service.net'
local buf = ''
---@async
@@ -44,29 +41,24 @@ brave.on('loadProtoBySocket', function (param)
end
end)
+ local lsclient = net.connect('tcp', '127.0.0.1', param.port)
+ local lsmaster = net.connect('unix', param.unixPath)
+
+ assert(lsclient)
+ assert(lsmaster)
+
+ function lsclient:on_data(data)
+ buf = buf .. data
+ coroutine.resume(parser)
+ end
+
+ function lsmaster:on_data(data)
+ lsclient:write(data)
+ net.update()
+ end
+
while true do
- local rd = socket.select({rfd, wfd}, nil, 10)
- if not rd or #rd == 0 then
- goto continue
- end
- if util.arrayHas(rd, wfd) then
- local needSend = wfd:recv()
- if needSend then
- rfd:send(needSend)
- elseif needSend == nil then
- error('socket closed!')
- end
- end
- if util.arrayHas(rd, rfd) then
- local recved = rfd:recv()
- if recved then
- buf = buf .. recved
- elseif recved == nil then
- error('socket closed!')
- end
- coroutine.resume(parser)
- end
- ::continue::
+ net.update(10)
end
end)
diff --git a/script/cli/check.lua b/script/cli/check.lua
index 94f34b2a..4295fa06 100644
--- a/script/cli/check.lua
+++ b/script/cli/check.lua
@@ -9,6 +9,7 @@ local lang = require 'language'
local define = require 'proto.define'
local config = require 'config.config'
local fs = require 'bee.filesystem'
+local provider = require 'provider'
require 'vm'
@@ -52,6 +53,8 @@ lclient():start(function (client)
io.write(lang.script('CLI_CHECK_INITING'))
+ provider.updateConfig(rootUri)
+
ws.awaitReady(rootUri)
local disables = util.arrayToHash(config.get(rootUri, 'Lua.diagnostics.disable'))
@@ -96,7 +99,10 @@ end
if count == 0 then
print(lang.script('CLI_CHECK_SUCCESS'))
else
- local outpath = LOGPATH .. '/check.json'
+ local outpath = CHECK_OUT_PATH
+ if outpath == nil then
+ outpath = LOGPATH .. '/check.json'
+ end
util.saveFile(outpath, jsonb.beautify(results))
print(lang.script('CLI_CHECK_RESULTS', count, outpath))
diff --git a/script/cli/doc.lua b/script/cli/doc.lua
index c643deea..9140a258 100644
--- a/script/cli/doc.lua
+++ b/script/cli/doc.lua
@@ -65,6 +65,7 @@ local function packObject(source, mark)
end
if source.type == 'function.return' then
new['desc'] = source.comment and getDesc(source.comment)
+ new['rawdesc'] = source.comment and getDesc(source.comment, true)
end
if source.type == 'doc.type.table' then
new['fields'] = packObject(source.fields, mark)
@@ -82,6 +83,7 @@ local function packObject(source, mark)
end
if source.bindDocs then
new['desc'] = getDesc(source)
+ new['rawdesc'] = getDesc(source, true)
end
new['view'] = new['view'] or vm.getInfer(source):view(ws.rootUri)
end
@@ -115,6 +117,7 @@ local function collectTypes(global, results)
name = global.name,
type = 'type',
desc = nil,
+ rawdesc = nil,
defines = {},
fields = {},
}
@@ -131,6 +134,7 @@ local function collectTypes(global, results)
extends = getExtends(set),
}
result.desc = result.desc or getDesc(set)
+ result.rawdesc = result.rawdesc or getDesc(set, true)
::CONTINUE::
end
if #result.defines == 0 then
@@ -163,6 +167,7 @@ local function collectTypes(global, results)
field.start = source.start
field.finish = source.finish
field.desc = getDesc(source)
+ field.rawdesc = getDesc(source, true)
field.extends = packObject(source.extends)
return
end
@@ -180,6 +185,7 @@ local function collectTypes(global, results)
field.start = source.start
field.finish = source.finish
field.desc = getDesc(source)
+ field.rawdesc = getDesc(source, true)
field.extends = packObject(source.value)
return
end
@@ -199,6 +205,7 @@ local function collectTypes(global, results)
field.start = source.start
field.finish = source.finish
field.desc = getDesc(source)
+ field.rawdesc = getDesc(source, true)
field.extends = packObject(source.value)
return
end
@@ -237,6 +244,7 @@ local function collectVars(global, results)
extends = packObject(set.value),
}
result.desc = result.desc or getDesc(set)
+ result.rawdesc = result.rawdesc or getDesc(set, true)
end
end
if #result.defines == 0 then
diff --git a/script/config/template.lua b/script/config/template.lua
index 436f5e1a..ba0fd503 100644
--- a/script/config/template.lua
+++ b/script/config/template.lua
@@ -317,7 +317,12 @@ local template = {
['Lua.workspace.maxPreload'] = Type.Integer >> 5000,
['Lua.workspace.preloadFileSize'] = Type.Integer >> 500,
['Lua.workspace.library'] = Type.Array(Type.String),
- ['Lua.workspace.checkThirdParty'] = Type.Boolean >> true,
+ ['Lua.workspace.checkThirdParty'] = Type.Or(Type.String >> 'Ask' << {
+ 'Ask',
+ 'Apply',
+ 'ApplyInMemory',
+ 'Disable',
+ }, Type.Boolean),
['Lua.workspace.userThirdParty'] = Type.Array(Type.String),
['Lua.completion.enable'] = Type.Boolean >> true,
['Lua.completion.callSnippet'] = Type.String >> 'Disable' << {
diff --git a/script/core/code-action.lua b/script/core/code-action.lua
index e226e03b..720cd4c4 100644
--- a/script/core/code-action.lua
+++ b/script/core/code-action.lua
@@ -4,6 +4,11 @@ local util = require 'utility'
local sp = require 'bee.subprocess'
local guide = require "parser.guide"
local converter = require 'proto.converter'
+local autoreq = require 'core.completion.auto-require'
+local rpath = require 'workspace.require-path'
+local furi = require 'file-uri'
+local undefined = require 'core.diagnostics.undefined-global'
+local vm = require 'vm'
---@param uri uri
---@param row integer
@@ -676,6 +681,54 @@ local function checkJsonToLua(results, uri, start, finish)
}
end
+local function findRequireTargets(visiblePaths)
+ local targets = {}
+ for _, visible in ipairs(visiblePaths) do
+ targets[#targets+1] = visible.name
+ end
+ return targets
+end
+
+local function checkMissingRequire(results, uri, start, finish)
+ local state = files.getState(uri)
+ local text = files.getText(uri)
+ if not state or not text then
+ return
+ end
+
+ local function addRequires(global, endpos)
+ autoreq.check(state, global, endpos, function(moduleFile, stemname, targetSource)
+ local visiblePaths = rpath.getVisiblePath(uri, furi.decode(moduleFile))
+ if not visiblePaths or #visiblePaths == 0 then return end
+
+ for _, target in ipairs(findRequireTargets(visiblePaths)) do
+ results[#results+1] = {
+ title = lang.script('ACTION_AUTOREQUIRE', target, global),
+ kind = 'refactor.rewrite',
+ command = {
+ title = 'autoRequire',
+ command = 'lua.autoRequire',
+ arguments = {
+ {
+ uri = guide.getUri(state.ast),
+ target = moduleFile,
+ name = global,
+ requireName = target
+ },
+ },
+ }
+ }
+ end
+ end)
+ end
+
+ guide.eachSourceBetween(state.ast, start, finish, function (source)
+ if vm.isUndefinedGlobal(source) then
+ addRequires(source[1], source.finish)
+ end
+ end)
+end
+
return function (uri, start, finish, diagnostics)
local ast = files.getState(uri)
if not ast then
@@ -688,6 +741,7 @@ return function (uri, start, finish, diagnostics)
checkSwapParams(results, uri, start, finish)
--checkExtractAsFunction(results, uri, start, finish)
checkJsonToLua(results, uri, start, finish)
+ checkMissingRequire(results, uri, start, finish)
return results
end
diff --git a/script/core/code-lens.lua b/script/core/code-lens.lua
index ed95ea6a..bc39ec86 100644
--- a/script/core/code-lens.lua
+++ b/script/core/code-lens.lua
@@ -6,7 +6,7 @@ local getRef = require 'core.reference'
local lang = require 'language'
---@class parser.state
----@field package _codeLens codeLens
+---@field package _codeLens? codeLens
---@class codeLens.resolving
---@field mode 'reference'
@@ -14,7 +14,6 @@ local lang = require 'language'
---@class codeLens.result
---@field position integer
----@field uri uri
---@field id integer
---@class codeLens
diff --git a/script/core/command/autoRequire.lua b/script/core/command/autoRequire.lua
index 32911d92..a96cc918 100644
--- a/script/core/command/autoRequire.lua
+++ b/script/core/command/autoRequire.lua
@@ -135,6 +135,7 @@ return function (data)
local uri = data.uri
local target = data.target
local name = data.name
+ local requireName = data.requireName
local state = files.getState(uri)
if not state then
return
@@ -149,11 +150,13 @@ return function (data)
return #a.name < #b.name
end)
- local result = askAutoRequire(uri, visiblePaths)
- if not result then
- return
+ if not requireName then
+ requireName = askAutoRequire(uri, visiblePaths)
+ if not requireName then
+ return
+ end
end
local offset, fmt = findInsertRow(uri)
- applyAutoRequire(uri, offset, name, result, fmt)
+ applyAutoRequire(uri, offset, name, requireName, fmt)
end
diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua
index 0ec503de..acb3adbe 100644
--- a/script/core/completion/completion.lua
+++ b/script/core/completion/completion.lua
@@ -504,7 +504,8 @@ local function checkFieldThen(state, name, src, word, startPos, position, parent
local kind = define.CompletionItemKind.Field
if (value.type == 'function' and not vm.isVarargFunctionWithOverloads(value))
or value.type == 'doc.type.function' then
- if oop then
+ local isMethod = value.parent.type == 'setmethod'
+ if isMethod then
kind = define.CompletionItemKind.Method
else
kind = define.CompletionItemKind.Function
@@ -512,6 +513,7 @@ local function checkFieldThen(state, name, src, word, startPos, position, parent
buildFunction(results, src, value, oop, {
label = name,
kind = kind,
+ isMethod = isMethod,
match = name:match '^[^(]+',
insertText = name:match '^[^(]+',
deprecated = vm.getDeprecated(src) and true or nil,
@@ -626,12 +628,43 @@ local function checkFieldOfRefs(refs, state, word, startPos, position, parent, o
end
::CONTINUE::
end
+
+ local fieldResults = {}
for name, src in util.sortPairs(fields) do
if src then
- checkFieldThen(state, name, src, word, startPos, position, parent, oop, results)
+ checkFieldThen(state, name, src, word, startPos, position, parent, oop, fieldResults)
await.delay()
end
end
+
+ local scoreMap = {}
+ for i, res in ipairs(fieldResults) do
+ scoreMap[res] = i
+ end
+ table.sort(fieldResults, function (a, b)
+ local score1 = scoreMap[a]
+ local score2 = scoreMap[b]
+ if oop then
+ if not a.isMethod then
+ score1 = score1 + 10000
+ end
+ if not b.isMethod then
+ score2 = score2 + 10000
+ end
+ else
+ if a.isMethod then
+ score1 = score1 + 10000
+ end
+ if b.isMethod then
+ score2 = score2 + 10000
+ end
+ end
+ return score1 < score2
+ end)
+
+ for _, res in ipairs(fieldResults) do
+ results[#results+1] = res
+ end
end
---@async
@@ -1218,6 +1251,46 @@ local function insertDocEnum(state, pos, doc, enums)
return enums
end
+---@param state parser.state
+---@param pos integer
+---@param doc vm.node.object
+---@param enums table[]
+---@return table[]?
+local function insertDocEnumKey(state, pos, doc, enums)
+ local tbl = doc.bindSource
+ if not tbl then
+ return nil
+ end
+ local keyEnums = {}
+ for _, field in ipairs(tbl) do
+ if field.type == 'tablefield'
+ or field.type == 'tableindex' then
+ if not field.value then
+ goto CONTINUE
+ end
+ local key = guide.getKeyName(field)
+ if not key then
+ goto CONTINUE
+ end
+ enums[#enums+1] = {
+ label = ('%q'):format(key),
+ kind = define.CompletionItemKind.EnumMember,
+ id = stack(field, function (newField) ---@async
+ return {
+ detail = buildDetail(newField),
+ description = buildDesc(newField),
+ }
+ end),
+ }
+ ::CONTINUE::
+ end
+ end
+ for _, enum in ipairs(keyEnums) do
+ enums[#enums+1] = enum
+ end
+ return enums
+end
+
local function buildInsertDocFunction(doc)
local args = {}
for i, arg in ipairs(doc.args) do
@@ -1283,7 +1356,11 @@ local function insertEnum(state, pos, src, enums, isInArray, mark)
elseif src.type == 'global' and src.cate == 'type' then
for _, set in ipairs(src:getSets(state.uri)) do
if set.type == 'doc.enum' then
- insertDocEnum(state, pos, set, enums)
+ if vm.docHasAttr(set, 'key') then
+ insertDocEnumKey(state, pos, set, enums)
+ else
+ insertDocEnum(state, pos, set, enums)
+ end
end
end
end
@@ -1539,14 +1616,13 @@ local function checkTableLiteralField(state, position, tbl, fields, results)
end
end
if left then
- local hasResult = false
+ local fieldResults = {}
for _, field in ipairs(fields) do
local name = guide.getKeyName(field)
if name
and not mark[name]
and matchKey(left, tostring(name)) then
- hasResult = true
- results[#results+1] = {
+ local res = {
label = guide.getKeyName(field),
kind = define.CompletionItemKind.Property,
id = stack(field, function (newField) ---@async
@@ -1556,9 +1632,20 @@ local function checkTableLiteralField(state, position, tbl, fields, results)
}
end),
}
+ if field.optional
+ or vm.compileNode(field):isNullable() then
+ res.insertText = res.label
+ res.label = res.label.. '?'
+ end
+ fieldResults[#fieldResults+1] = res
end
end
- return hasResult
+ util.sortByScore(fieldResults, {
+ function (r) return r.insertText and 0 or 1 end,
+ util.sortCallbackOfIndex(fieldResults),
+ })
+ util.arrayMerge(results, fieldResults)
+ return #fieldResults > 0
end
end
@@ -1571,7 +1658,8 @@ local function tryCallArg(state, position, results)
if arg and arg.type == 'function' then
return
end
- local node = vm.compileCallArg({ type = 'dummyarg' }, call, argIndex)
+ ---@diagnostic disable-next-line: missing-fields
+ local node = vm.compileCallArg({ type = 'dummyarg', uri = state.uri }, call, argIndex)
if not node then
return
end
@@ -2070,7 +2158,7 @@ local function tryluaDocByErr(state, position, err, docState, results)
end
end
-local function buildluaDocOfFunction(func)
+local function buildluaDocOfFunction(func, pad)
local index = 1
local buf = {}
buf[#buf+1] = '${1:comment}'
@@ -2094,7 +2182,8 @@ local function buildluaDocOfFunction(func)
local funcArg = func.args[n]
if funcArg[1] and funcArg.type ~= 'self' then
index = index + 1
- buf[#buf+1] = ('---@param %s ${%d:%s}'):format(
+ buf[#buf+1] = ('---%s@param %s ${%d:%s}'):format(
+ pad and ' ' or '',
funcArg[1],
index,
arg
@@ -2103,7 +2192,8 @@ local function buildluaDocOfFunction(func)
end
for _, rtn in ipairs(returns) do
index = index + 1
- buf[#buf+1] = ('---@return ${%d:%s}'):format(
+ buf[#buf+1] = ('---%s@return ${%d:%s}'):format(
+ pad and ' ' or '',
index,
rtn
)
@@ -2112,7 +2202,7 @@ local function buildluaDocOfFunction(func)
return insertText
end
-local function tryluaDocOfFunction(doc, results)
+local function tryluaDocOfFunction(doc, results, pad)
if not doc.bindSource then
return
end
@@ -2134,7 +2224,7 @@ local function tryluaDocOfFunction(doc, results)
end
end
end
- local insertText = buildluaDocOfFunction(func)
+ local insertText = buildluaDocOfFunction(func, pad)
results[#results+1] = {
label = '@param;@return',
kind = define.CompletionItemKind.Snippet,
@@ -2152,9 +2242,9 @@ local function tryLuaDoc(state, position, results)
end
if doc.type == 'doc.comment' then
local line = doc.originalComment.text
- -- 尝试 ---$
- if line == '-' then
- tryluaDocOfFunction(doc, results)
+ -- 尝试 '---$' or '--- $'
+ if line == '-' or line == '- ' then
+ tryluaDocOfFunction(doc, results, line == '- ')
return
end
-- 尝试 ---@$
diff --git a/script/core/completion/keyword.lua b/script/core/completion/keyword.lua
index e6f50242..aa0e2148 100644
--- a/script/core/completion/keyword.lua
+++ b/script/core/completion/keyword.lua
@@ -3,6 +3,7 @@ local files = require 'files'
local guide = require 'parser.guide'
local config = require 'config'
local util = require 'utility'
+local lookback = require 'core.look-backward'
local keyWordMap = {
{ 'do', function(info, results)
@@ -372,17 +373,35 @@ end"
else
newText = '::continue::'
end
+ local additional = {}
+
+ local word = lookback.findWord(info.state.lua, guide.positionToOffset(info.state, info.start) - 1)
+ if word ~= 'goto' then
+ additional[#additional+1] = {
+ start = info.start,
+ finish = info.start,
+ newText = 'goto ',
+ }
+ end
+
+ local hasContinue = guide.eachSourceType(mostInsideBlock, 'label', function (src)
+ if src[1] == 'continue' then
+ return true
+ end
+ end)
+
+ if not hasContinue then
+ additional[#additional+1] = {
+ start = endPos,
+ finish = endPos,
+ newText = newText,
+ }
+ end
results[#results+1] = {
label = 'goto continue ..',
kind = define.CompletionItemKind.Snippet,
- insertText = "goto continue",
- additionalTextEdits = {
- {
- start = endPos,
- finish = endPos,
- newText = newText,
- }
- }
+ insertText = "continue",
+ additionalTextEdits = additional,
}
return true
end }
diff --git a/script/core/diagnostics/cast-local-type.lua b/script/core/diagnostics/cast-local-type.lua
index 1998b915..26445374 100644
--- a/script/core/diagnostics/cast-local-type.lua
+++ b/script/core/diagnostics/cast-local-type.lua
@@ -16,6 +16,9 @@ return function (uri, callback)
if not loc.ref then
return
end
+ if loc[1] == '_' then
+ return
+ end
await.delay()
local locNode = vm.compileNode(loc)
if not locNode.hasDefined then
diff --git a/script/core/diagnostics/helper/missing-doc-helper.lua b/script/core/diagnostics/helper/missing-doc-helper.lua
index 84221693..116173f2 100644
--- a/script/core/diagnostics/helper/missing-doc-helper.lua
+++ b/script/core/diagnostics/helper/missing-doc-helper.lua
@@ -38,8 +38,9 @@ end
local function checkFunction(source, callback, commentId, paramId, returnId)
local functionName = source.parent[1]
+ local argCount = source.args and #source.args or 0
- if #source.args == 0 and not source.returns and not source.bindDocs then
+ if argCount == 0 and not source.returns and not source.bindDocs then
callback {
start = source.start,
finish = source.finish,
@@ -47,10 +48,11 @@ local function checkFunction(source, callback, commentId, paramId, returnId)
}
end
- if #source.args > 0 then
+ if argCount > 0 then
for _, arg in ipairs(source.args) do
local argName = arg[1]
- if argName ~= 'self' then
+ if argName ~= 'self'
+ and argName ~= '_' then
if not findParam(source.bindDocs, argName) then
callback {
start = arg.start,
diff --git a/script/core/diagnostics/incomplete-signature-doc.lua b/script/core/diagnostics/incomplete-signature-doc.lua
index 91f2db74..1ffbb77a 100644
--- a/script/core/diagnostics/incomplete-signature-doc.lua
+++ b/script/core/diagnostics/incomplete-signature-doc.lua
@@ -38,6 +38,19 @@ local function findReturn(docs, index)
return false
end
+--- check if there's any signature doc (@param or @return), or just comments, @async, ...
+local function findSignatureDoc(docs)
+ if not docs then
+ return false
+ end
+ for _, doc in ipairs(docs) do
+ if doc.type == 'doc.return' or doc.type == 'doc.param' then
+ return true
+ end
+ end
+ return false
+end
+
---@async
return function (uri, callback)
local state = files.getState(uri)
@@ -57,17 +70,22 @@ return function (uri, callback)
return
end
- local functionName = source.parent[1]
+ --- don't apply rule if there is no @param or @return annotation yet
+ --- so comments and @async can be applied without the need for a full documentation
+ if(not findSignatureDoc(source.bindDocs)) then
+ return
+ end
- if #source.args > 0 then
+ if source.args and #source.args > 0 then
for _, arg in ipairs(source.args) do
local argName = arg[1]
- if argName ~= 'self' then
+ if argName ~= 'self'
+ and argName ~= '_' then
if not findParam(source.bindDocs, argName) then
callback {
start = arg.start,
finish = arg.finish,
- message = lang.script('DIAG_INCOMPLETE_SIGNATURE_DOC_PARAM', argName, functionName),
+ message = lang.script('DIAG_INCOMPLETE_SIGNATURE_DOC_PARAM', argName),
}
end
end
@@ -81,7 +99,7 @@ return function (uri, callback)
callback {
start = expr.start,
finish = expr.finish,
- message = lang.script('DIAG_INCOMPLETE_SIGNATURE_DOC_RETURN', index, functionName),
+ message = lang.script('DIAG_INCOMPLETE_SIGNATURE_DOC_RETURN', index),
}
end
end
diff --git a/script/core/diagnostics/inject-field.lua b/script/core/diagnostics/inject-field.lua
new file mode 100644
index 00000000..2866eef8
--- /dev/null
+++ b/script/core/diagnostics/inject-field.lua
@@ -0,0 +1,147 @@
+local files = require 'files'
+local vm = require 'vm'
+local lang = require 'language'
+local guide = require 'parser.guide'
+local await = require 'await'
+local hname = require 'core.hover.name'
+
+local skipCheckClass = {
+ ['unknown'] = true,
+ ['any'] = true,
+ ['table'] = true,
+}
+
+---@async
+return function (uri, callback)
+ local ast = files.getState(uri)
+ if not ast then
+ return
+ end
+
+ ---@async
+ local function checkInjectField(src)
+ await.delay()
+
+ local node = src.node
+ if not node then
+ return
+ end
+ local ok
+ for view in vm.getInfer(node):eachView(uri) do
+ if skipCheckClass[view] then
+ return
+ end
+ ok = true
+ end
+ if not ok then
+ return
+ end
+
+ local isExact
+ local class = vm.getDefinedClass(uri, node)
+ if class then
+ for _, doc in ipairs(class:getSets(uri)) do
+ if vm.docHasAttr(doc, 'exact') then
+ isExact = true
+ break
+ end
+ end
+ if not isExact then
+ return
+ end
+ if src.type == 'setmethod'
+ and not guide.getSelfNode(node) then
+ return
+ end
+ end
+
+ for _, def in ipairs(vm.getDefs(src)) do
+ local dnode = def.node
+ if dnode
+ and not isExact
+ and vm.getDefinedClass(uri, dnode) then
+ return
+ end
+ if def.type == 'doc.type.field' then
+ return
+ end
+ if def.type == 'doc.field' then
+ return
+ end
+ end
+
+ local howToFix = ''
+ if not isExact then
+ howToFix = lang.script('DIAG_INJECT_FIELD_FIX_CLASS', {
+ node = hname(node),
+ fix = '---@class',
+ })
+ for _, ndef in ipairs(vm.getDefs(node)) do
+ if ndef.type == 'doc.type.table' then
+ howToFix = lang.script('DIAG_INJECT_FIELD_FIX_TABLE', {
+ fix = '[any]: any',
+ })
+ break
+ end
+ end
+ end
+
+ local message = lang.script('DIAG_INJECT_FIELD', {
+ class = vm.getInfer(node):view(uri),
+ field = guide.getKeyName(src),
+ fix = howToFix,
+ })
+ if src.type == 'setfield' and src.field then
+ callback {
+ start = src.field.start,
+ finish = src.field.finish,
+ message = message,
+ }
+ elseif src.type == 'setmethod' and src.method then
+ callback {
+ start = src.method.start,
+ finish = src.method.finish,
+ message = message,
+ }
+ end
+ end
+ guide.eachSourceType(ast.ast, 'setfield', checkInjectField)
+ guide.eachSourceType(ast.ast, 'setmethod', checkInjectField)
+
+ ---@async
+ local function checkExtraTableField(src)
+ await.delay()
+
+ if not src.bindSource then
+ return
+ end
+ if not vm.docHasAttr(src, 'exact') then
+ return
+ end
+ local value = src.bindSource.value
+ if not value or value.type ~= 'table' then
+ return
+ end
+ for _, field in ipairs(value) do
+ local defs = vm.getDefs(field)
+ for _, def in ipairs(defs) do
+ if def.type == 'doc.field' then
+ goto nextField
+ end
+ end
+ local message = lang.script('DIAG_INJECT_FIELD', {
+ class = vm.getInfer(src):view(uri),
+ field = guide.getKeyName(src),
+ fix = '',
+ })
+ callback {
+ start = field.start,
+ finish = field.finish,
+ message = message,
+ }
+ ::nextField::
+ end
+ end
+
+ guide.eachSourceType(ast.ast, 'doc.class', checkExtraTableField)
+end
diff --git a/script/core/diagnostics/missing-fields.lua b/script/core/diagnostics/missing-fields.lua
new file mode 100644
index 00000000..210920fd
--- /dev/null
+++ b/script/core/diagnostics/missing-fields.lua
@@ -0,0 +1,84 @@
+local vm = require 'vm'
+local files = require 'files'
+local guide = require 'parser.guide'
+local await = require 'await'
+local lang = require 'language'
+
+---@async
+return function (uri, callback)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+
+ ---@async
+ guide.eachSourceType(state.ast, 'table', function (src)
+ await.delay()
+
+ local defs = vm.getDefs(src)
+ for _, def in ipairs(defs) do
+ if def.type == 'doc.class' and def.bindSource then
+ if guide.isInRange(def.bindSource, src.start) then
+ return
+ end
+ end
+ if def.type == 'doc.type.array'
+ or def.type == 'doc.type.table' then
+ return
+ end
+ end
+ local warnings = {}
+ for _, def in ipairs(defs) do
+ if def.type == 'doc.class' then
+ if not def.fields then
+ return
+ end
+
+ local requiresKeys = {}
+ for _, field in ipairs(def.fields) do
+ if not field.optional
+ and not vm.compileNode(field):isNullable() then
+ local key = vm.getKeyName(field)
+ if key and not requiresKeys[key] then
+ requiresKeys[key] = true
+ requiresKeys[#requiresKeys+1] = key
+ end
+ end
+ end
+
+ if #requiresKeys == 0 then
+ return
+ end
+ local myKeys = {}
+ for _, field in ipairs(src) do
+ local key = vm.getKeyName(field)
+ if key then
+ myKeys[key] = true
+ end
+ end
+
+ local missedKeys = {}
+ for _, key in ipairs(requiresKeys) do
+ if not myKeys[key] then
+ missedKeys[#missedKeys+1] = ('`%s`'):format(key)
+ end
+ end
+
+ if #missedKeys == 0 then
+ return
+ end
+
+ warnings[#warnings+1] = lang.script('DIAG_MISSING_FIELDS', def.class[1], table.concat(missedKeys, ', '))
+ end
+ end
+
+ if #warnings == 0 then
+ return
+ end
+ callback {
+ start = src.start,
+ finish = src.finish,
+ message = table.concat(warnings, '\n')
+ }
+ end)
+end
diff --git a/script/core/diagnostics/missing-local-export-doc.lua b/script/core/diagnostics/missing-local-export-doc.lua
index 5825c115..da413961 100644
--- a/script/core/diagnostics/missing-local-export-doc.lua
+++ b/script/core/diagnostics/missing-local-export-doc.lua
@@ -10,7 +10,13 @@ local function findSetField(ast, name, callback)
await.delay()
if source.node[1] == name then
local funcPtr = source.value.node
+ if not funcPtr then
+ return
+ end
local func = funcPtr.value
+ if not func then
+ return
+ end
if funcPtr.type == 'local' and func.type == 'function' then
helper.CheckFunction(func, callback, 'DIAG_MISSING_LOCAL_EXPORT_DOC_COMMENT', 'DIAG_MISSING_LOCAL_EXPORT_DOC_PARAM', 'DIAG_MISSING_LOCAL_EXPORT_DOC_RETURN')
end
diff --git a/script/core/diagnostics/param-type-mismatch.lua b/script/core/diagnostics/param-type-mismatch.lua
index da39c5e1..acbf9c8c 100644
--- a/script/core/diagnostics/param-type-mismatch.lua
+++ b/script/core/diagnostics/param-type-mismatch.lua
@@ -32,26 +32,28 @@ end
---@param funcNode vm.node
---@param i integer
+---@param uri uri
---@return vm.node?
-local function getDefNode(funcNode, i)
+local function getDefNode(funcNode, i, uri)
local defNode = vm.createNode()
- for f in funcNode:eachObject() do
- if f.type == 'function'
- or f.type == 'doc.type.function' then
- local param = f.args and f.args[i]
+ for src in funcNode:eachObject() do
+ if src.type == 'function'
+ or src.type == 'doc.type.function' then
+ local param = src.args and src.args[i]
if param then
defNode:merge(vm.compileNode(param))
if param[1] == '...' then
defNode:addOptional()
end
-
- expandGenerics(defNode)
end
end
end
if defNode:isEmpty() then
return nil
end
+
+ expandGenerics(defNode)
+
return defNode
end
@@ -91,7 +93,7 @@ return function (uri, callback)
if not refNode then
goto CONTINUE
end
- local defNode = getDefNode(funcNode, i)
+ local defNode = getDefNode(funcNode, i, uri)
if not defNode then
goto CONTINUE
end
diff --git a/script/core/diagnostics/undefined-field.lua b/script/core/diagnostics/undefined-field.lua
index a83241f5..4fd55966 100644
--- a/script/core/diagnostics/undefined-field.lua
+++ b/script/core/diagnostics/undefined-field.lua
@@ -8,13 +8,6 @@ local skipCheckClass = {
['unknown'] = true,
['any'] = true,
['table'] = true,
- ['nil'] = true,
- ['number'] = true,
- ['integer'] = true,
- ['boolean'] = true,
- ['function'] = true,
- ['userdata'] = true,
- ['lightuserdata'] = true,
}
---@async
@@ -61,5 +54,4 @@ return function (uri, callback)
end
guide.eachSourceType(ast.ast, 'getfield', checkUndefinedField)
guide.eachSourceType(ast.ast, 'getmethod', checkUndefinedField)
- guide.eachSourceType(ast.ast, 'getindex', checkUndefinedField)
end
diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua
index 179c9204..d9d94959 100644
--- a/script/core/diagnostics/undefined-global.lua
+++ b/script/core/diagnostics/undefined-global.lua
@@ -20,41 +20,21 @@ return function (uri, callback)
return
end
- local dglobals = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals'))
- local rspecial = config.get(uri, 'Lua.runtime.special')
- local cache = {}
-
-- 遍历全局变量,检查所有没有 set 模式的全局变量
guide.eachSourceType(state.ast, 'getglobal', function (src) ---@async
- local key = src[1]
- if not key then
- return
- end
- if dglobals[key] then
- return
- end
- if rspecial[key] then
- return
- end
- local node = src.node
- if node.tag ~= '_ENV' then
- return
- end
- if cache[key] == nil then
- await.delay()
- cache[key] = vm.hasGlobalSets(uri, 'variable', key)
- end
- if cache[key] then
- return
- end
- local message = lang.script('DIAG_UNDEF_GLOBAL', key)
- if requireLike[key:lower()] then
- message = ('%s(%s)'):format(message, lang.script('DIAG_REQUIRE_LIKE', key))
+ if vm.isUndefinedGlobal(src) then
+ local key = src[1]
+ local message = lang.script('DIAG_UNDEF_GLOBAL', key)
+ if requireLike[key:lower()] then
+ message = ('%s(%s)'):format(message, lang.script('DIAG_REQUIRE_LIKE', key))
+ end
+
+ callback {
+ start = src.start,
+ finish = src.finish,
+ message = message,
+ undefinedGlobal = src[1]
+ }
end
- callback {
- start = src.start,
- finish = src.finish,
- message = message,
- }
end)
end
diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua
index f5890b21..75189b06 100644
--- a/script/core/hover/description.lua
+++ b/script/core/hover/description.lua
@@ -336,7 +336,7 @@ local function tryDocFieldComment(source)
end
end
-local function getFunctionComment(source)
+local function getFunctionComment(source, raw)
local docGroup = source.bindDocs
if not docGroup then
return
@@ -356,14 +356,14 @@ local function getFunctionComment(source)
if doc.type == 'doc.comment' then
local comment = normalizeComment(doc.comment.text, uri)
md:add('md', comment)
- elseif doc.type == 'doc.param' then
+ elseif doc.type == 'doc.param' and not raw then
if doc.comment then
md:add('md', ('@*param* `%s` — %s'):format(
doc.param[1],
doc.comment.text
))
end
- elseif doc.type == 'doc.return' then
+ elseif doc.type == 'doc.return' and not raw then
if hasReturnComment then
local name = {}
for _, rtn in ipairs(doc.returns) do
@@ -401,13 +401,13 @@ local function getFunctionComment(source)
end
---@async
-local function tryDocComment(source)
+local function tryDocComment(source, raw)
local md = markdown()
if source.value and source.value.type == 'function' then
source = source.value
end
if source.type == 'function' then
- local comment = getFunctionComment(source)
+ local comment = getFunctionComment(source, raw)
md:add('md', comment)
source = source.parent
end
@@ -429,7 +429,7 @@ local function tryDocComment(source)
end
---@async
-local function tryDocOverloadToComment(source)
+local function tryDocOverloadToComment(source, raw)
if source.type ~= 'doc.type.function' then
return
end
@@ -438,7 +438,7 @@ local function tryDocOverloadToComment(source)
or not doc.bindSource then
return
end
- local md = tryDocComment(doc.bindSource)
+ local md = tryDocComment(doc.bindSource, raw)
if md then
return md
end
@@ -477,38 +477,59 @@ local function tryDocEnum(source)
if not tbl then
return
end
- local md = markdown()
- md:add('lua', '{')
- for _, field in ipairs(tbl) do
- if field.type == 'tablefield'
- or field.type == 'tableindex' then
- if not field.value then
- goto CONTINUE
- end
- local key = guide.getKeyName(field)
- if not key then
- goto CONTINUE
- end
- if field.value.type == 'integer'
- or field.value.type == 'string' then
- md:add('lua', (' %s: %s = %s,'):format(key, field.value.type, field.value[1]))
+ if vm.docHasAttr(source, 'key') then
+ local md = markdown()
+ local keys = {}
+ for _, field in ipairs(tbl) do
+ if field.type == 'tablefield'
+ or field.type == 'tableindex' then
+ if not field.value then
+ goto CONTINUE
+ end
+ local key = guide.getKeyName(field)
+ if not key then
+ goto CONTINUE
+ end
+ keys[#keys+1] = ('%q'):format(key)
+ ::CONTINUE::
end
- if field.value.type == 'binary'
- or field.value.type == 'unary' then
- local number = vm.getNumber(field.value)
- if number then
- md:add('lua', (' %s: %s = %s,'):format(key, math.tointeger(number) and 'integer' or 'number', number))
+ end
+ md:add('lua', table.concat(keys, ' | '))
+ return md:string()
+ else
+ local md = markdown()
+ md:add('lua', '{')
+ for _, field in ipairs(tbl) do
+ if field.type == 'tablefield'
+ or field.type == 'tableindex' then
+ if not field.value then
+ goto CONTINUE
+ end
+ local key = guide.getKeyName(field)
+ if not key then
+ goto CONTINUE
+ end
+ if field.value.type == 'integer'
+ or field.value.type == 'string' then
+ md:add('lua', (' %s: %s = %s,'):format(key, field.value.type, field.value[1]))
+ end
+ if field.value.type == 'binary'
+ or field.value.type == 'unary' then
+ local number = vm.getNumber(field.value)
+ if number then
+ md:add('lua', (' %s: %s = %s,'):format(key, math.tointeger(number) and 'integer' or 'number', number))
+ end
end
+ ::CONTINUE::
end
- ::CONTINUE::
end
+ md:add('lua', '}')
+ return md:string()
end
- md:add('lua', '}')
- return md:string()
end
---@async
-return function (source)
+return function (source, raw)
if source.type == 'string' then
return asString(source)
end
@@ -518,10 +539,10 @@ return function (source)
if source.type == 'field' then
source = source.parent
end
- return tryDocOverloadToComment(source)
+ return tryDocOverloadToComment(source, raw)
or tryDocFieldComment(source)
or tyrDocParamComment(source)
- or tryDocComment(source)
+ or tryDocComment(source, raw)
or tryDocClassComment(source)
or tryDocModule(source)
or tryDocEnum(source)
diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua
index 6ce4dde9..62e51927 100644
--- a/script/core/hover/label.lua
+++ b/script/core/hover/label.lua
@@ -134,7 +134,7 @@ local function asField(source)
end
local function asDocFieldName(source)
- local name = source.field[1]
+ local name = vm.viewKey(source, guide.getUri(source)) or '?'
local class
for _, doc in ipairs(source.bindGroup) do
if doc.type == 'doc.class' then
@@ -143,10 +143,12 @@ local function asDocFieldName(source)
end
end
local view = vm.getInfer(source.extends):view(guide.getUri(source))
- if not class then
- return ('(field) ?.%s: %s'):format(name, view)
+ local className = class and class.class[1] or '?'
+ if name:match(guide.namePatternFull) then
+ return ('(field) %s.%s: %s'):format(className, name, view)
+ else
+ return ('(field) %s%s: %s'):format(className, name, view)
end
- return ('(field) %s.%s: %s'):format(class.class[1], name, view)
end
local function asString(source)
diff --git a/script/core/rename.lua b/script/core/rename.lua
index 507def20..cc5d37f3 100644
--- a/script/core/rename.lua
+++ b/script/core/rename.lua
@@ -3,42 +3,59 @@ local vm = require 'vm'
local util = require 'utility'
local findSource = require 'core.find-source'
local guide = require 'parser.guide'
+local config = require 'config'
local Forcing
+---@param str string
+---@return string
local function trim(str)
return str:match '^%s*(%S+)%s*$'
end
-local function isValidName(str)
+---@param uri uri
+---@param str string
+---@return boolean
+local function isValidName(uri, str)
if not str then
return false
end
- return str:match '^[%a_][%w_]*$'
+ local allowUnicode = config.get(uri, 'Lua.runtime.unicodeName')
+ if allowUnicode then
+ return str:match '^[%a_\x80-\xff][%w_\x80-\xff]*$'
+ else
+ return str:match '^[%a_][%w_]*$'
+ end
end
-local function isValidGlobal(str)
+---@param uri uri
+---@param str string
+---@return boolean
+local function isValidGlobal(uri, str)
if not str then
return false
end
for s in str:gmatch '[^%.]*' do
- if not isValidName(trim(s)) then
+ if not isValidName(uri, trim(s)) then
return false
end
end
return true
end
-local function isValidFunctionName(str)
- if isValidGlobal(str) then
+---@param uri uri
+---@param str string
+---@return boolean
+local function isValidFunctionName(uri, str)
+ if isValidGlobal(uri, str) then
return true
end
local offset = str:find(':', 1, true)
if not offset then
return false
end
- return isValidGlobal(trim(str:sub(1, offset-1)))
- and isValidName(trim(str:sub(offset+1)))
+ return isValidGlobal(uri, trim(str:sub(1, offset-1)))
+ and isValidName(uri, trim(str:sub(offset+1)))
end
local function isFunctionGlobalName(source)
@@ -54,7 +71,7 @@ local function isFunctionGlobalName(source)
end
local function renameLocal(source, newname, callback)
- if isValidName(newname) then
+ if isValidName(guide.getUri(source), newname) then
callback(source, source.start, source.finish, newname)
return
end
@@ -62,7 +79,7 @@ local function renameLocal(source, newname, callback)
end
local function renameField(source, newname, callback)
- if isValidName(newname) then
+ if isValidName(guide.getUri(source), newname) then
callback(source, source.start, source.finish, newname)
return true
end
@@ -108,11 +125,11 @@ local function renameField(source, newname, callback)
end
local function renameGlobal(source, newname, callback)
- if isValidGlobal(newname) then
+ if isValidGlobal(guide.getUri(source), newname) then
callback(source, source.start, source.finish, newname)
return true
end
- if isValidFunctionName(newname) then
+ if isValidFunctionName(guide.getUri(source), newname) then
callback(source, source.start, source.finish, newname)
return true
end
diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua
index 4d191b69..4e1d8e00 100644
--- a/script/core/semantic-tokens.lua
+++ b/script/core/semantic-tokens.lua
@@ -138,12 +138,20 @@ local Care = util.switch()
local uri = guide.getUri(loc)
-- 1. 值为函数的局部变量 | Local variable whose value is a function
if vm.getInfer(source):hasFunction(uri) then
- results[#results+1] = {
- start = source.start,
- finish = source.finish,
- type = define.TokenTypes['function'],
- modifieres = define.TokenModifiers.declaration,
- }
+ if source.type == 'local' then
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes['function'],
+ modifieres = define.TokenModifiers.declaration,
+ }
+ else
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes['function'],
+ }
+ end
return
end
-- 3. 特殊变量 | Special variableif source[1] == '_ENV' then
@@ -703,6 +711,14 @@ local Care = util.switch()
type = define.TokenTypes.namespace,
}
end)
+ : case 'doc.attr'
+ : call(function (source, options, results)
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.decorator,
+ }
+ end)
---@param state table
---@param results table
diff --git a/script/core/signature.lua b/script/core/signature.lua
index 63b0cd0d..98018b21 100644
--- a/script/core/signature.lua
+++ b/script/core/signature.lua
@@ -89,6 +89,39 @@ local function makeOneSignature(source, oop, index)
}
end
+local function isEventNotMatch(call, src)
+ if not call.args or not src.args then
+ return false
+ end
+ local literal, index
+ for i = 1, 2 do
+ if not call.args[i] then
+ break
+ end
+ literal = guide.getLiteral(call.args[i])
+ if literal then
+ index = i
+ break
+ end
+ end
+ if not literal then
+ return false
+ end
+ local event = src.args[index]
+ if not event or event.type ~= 'doc.type.arg' then
+ return false
+ end
+ if not event.extends
+ or #event.extends.types ~= 1 then
+ return false
+ end
+ local eventLiteral = event.extends.types[1] and guide.getLiteral(event.extends.types[1])
+ if eventLiteral == nil then
+ return false
+ end
+ return eventLiteral ~= literal
+end
+
---@async
local function makeSignatures(text, call, pos)
local func = call.node
@@ -139,7 +172,8 @@ local function makeSignatures(text, call, pos)
for src in node:eachObject() do
if (src.type == 'function' and not vm.isVarargFunctionWithOverloads(src))
or src.type == 'doc.type.function' then
- if not mark[src] then
+ if not mark[src]
+ and not isEventNotMatch(call, src) then
mark[src] = true
signs[#signs+1] = makeOneSignature(src, oop, index)
end
@@ -149,7 +183,8 @@ local function makeSignatures(text, call, pos)
if set.type == 'doc.class' then
for _, overload in ipairs(set.calls) do
local f = overload.overload
- if not mark[f] then
+ if not mark[f]
+ and not isEventNotMatch(call, src) then
mark[f] = true
signs[#signs+1] = makeOneSignature(f, oop, index)
end
diff --git a/script/file-uri.lua b/script/file-uri.lua
index 8e9dd938..8a075f7e 100644
--- a/script/file-uri.lua
+++ b/script/file-uri.lua
@@ -49,7 +49,7 @@ function m.encode(path)
--lower-case windows drive letters in /C:/fff or C:/fff
local start, finish, drive = path:find '/(%u):'
- if drive then
+ if drive and finish then
path = path:sub(1, start) .. drive:lower() .. path:sub(finish, -1)
end
@@ -102,6 +102,9 @@ function m.isValid(uri)
if path == '' then
return false
end
+ if scheme ~= 'file' then
+ return false
+ end
return true
end
diff --git a/script/files.lua b/script/files.lua
index 7998ceed..b7a20e8b 100644
--- a/script/files.lua
+++ b/script/files.lua
@@ -19,20 +19,20 @@ local sp = require 'bee.subprocess'
local pub = require 'pub'
---@class file
----@field uri uri
----@field content string
----@field ref? integer
----@field trusted? boolean
----@field rows? integer[]
----@field originText? string
----@field text string
----@field version? integer
----@field originLines? integer[]
----@field diffInfo? table[]
----@field cache table
----@field id integer
----@field state? parser.state
----@field compileCount integer
+---@field uri uri
+---@field ref? integer
+---@field trusted? boolean
+---@field rows? integer[]
+---@field originText? string
+---@field text? string
+---@field version? integer
+---@field originLines? integer[]
+---@field diffInfo? table[]
+---@field cache? table
+---@field id integer
+---@field state? parser.state
+---@field compileCount? integer
+---@field words? table
---@class files
---@field lazyCache? lazy-cacher
@@ -725,7 +725,8 @@ end
---@class parser.state
---@field diffInfo? table[]
---@field originLines? integer[]
----@field originText string
+---@field originText? string
+---@field lua? string
--- 获取文件语法树
---@param uri uri
diff --git a/script/global.d.lua b/script/global.d.lua
index f84ff0e4..b44d6371 100644
--- a/script/global.d.lua
+++ b/script/global.d.lua
@@ -55,6 +55,12 @@ DOC = ''
---@type string | '"Error"' | '"Warning"' | '"Information"' | '"Hint"'
CHECKLEVEL = 'Warning'
+--Where to write the check results (JSON).
+--
+--If nil, use `LOGPATH/check.json`.
+---@type string|nil
+CHECK_OUT_PATH = ''
+
---@type 'trace' | 'debug' | 'info' | 'warn' | 'error'
LOGLEVEL = 'warn'
@@ -77,3 +83,6 @@ jit = false
-- connect to client by socket
---@type integer
SOCKET = 0
+
+-- Allowing the use of the root directory or home directory as the workspace
+FORCE_ACCEPT_WORKSPACE = false
diff --git a/script/library.lua b/script/library.lua
index 3a9bbbc6..4446797a 100644
--- a/script/library.lua
+++ b/script/library.lua
@@ -469,34 +469,45 @@ end
local hasAsked = {}
---@async
-local function askFor3rd(uri, cfg)
+local function askFor3rd(uri, cfg, checkThirdParty)
if hasAsked[cfg.name] then
return nil
end
- hasAsked[cfg.name] = true
- local yes1 = lang.script.WINDOW_APPLY_WHIT_SETTING
- local yes2 = lang.script.WINDOW_APPLY_WHITOUT_SETTING
- local no = lang.script.WINDOW_DONT_SHOW_AGAIN
- local result = client.awaitRequestMessage('Info'
- , lang.script('WINDOW_ASK_APPLY_LIBRARY', cfg.name)
- , {yes1, yes2, no}
- )
- if not result then
- return nil
- end
- if result == yes1 then
+
+ if checkThirdParty == 'Apply' then
apply3rd(uri, cfg, false)
- elseif result == yes2 then
+ elseif checkThirdParty == 'ApplyInMemory' then
apply3rd(uri, cfg, true)
- else
- client.setConfig({
- {
- key = 'Lua.workspace.checkThirdParty',
- action = 'set',
- value = false,
- uri = uri,
- },
- }, false)
+ elseif checkThirdParty == 'Disable' then
+ return nil
+ elseif checkThirdParty == 'Ask' then
+ hasAsked[cfg.name] = true
+ local applyAndSetConfig = lang.script.WINDOW_APPLY_WHIT_SETTING
+ local applyInMemory = lang.script.WINDOW_APPLY_WHITOUT_SETTING
+ local dontShowAgain = lang.script.WINDOW_DONT_SHOW_AGAIN
+ local result = client.awaitRequestMessage('Info'
+ , lang.script('WINDOW_ASK_APPLY_LIBRARY', cfg.name)
+ , {applyAndSetConfig, applyInMemory, dontShowAgain}
+ )
+ if not result then
+ -- "If none got selected"
+ -- See: https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#window_showMessageRequest
+ return nil
+ end
+ if result == applyAndSetConfig then
+ apply3rd(uri, cfg, false)
+ elseif result == applyInMemory then
+ apply3rd(uri, cfg, true)
+ else
+ client.setConfig({
+ {
+ key = 'Lua.workspace.checkThirdParty',
+ action = 'set',
+ value = 'Disable',
+ uri = uri,
+ },
+ }, false)
+ end
end
end
@@ -517,7 +528,7 @@ local function wholeMatch(a, b)
return true
end
-local function check3rdByWords(uri, configs)
+local function check3rdByWords(uri, configs, checkThirdParty)
if not files.isLua(uri) then
return
end
@@ -546,7 +557,7 @@ local function check3rdByWords(uri, configs)
log.info('Found 3rd library by word: ', word, uri, library, inspect(config.get(uri, 'Lua.workspace.library')))
---@async
await.call(function ()
- askFor3rd(uri, cfg)
+ askFor3rd(uri, cfg, checkThirdParty)
end)
return
end
@@ -556,7 +567,7 @@ local function check3rdByWords(uri, configs)
end, id)
end
-local function check3rdByFileName(uri, configs)
+local function check3rdByFileName(uri, configs, checkThirdParty)
local path = ws.getRelativePath(uri)
if not path then
return
@@ -582,7 +593,7 @@ local function check3rdByFileName(uri, configs)
log.info('Found 3rd library by filename: ', filename, uri, library, inspect(config.get(uri, 'Lua.workspace.library')))
---@async
await.call(function ()
- askFor3rd(uri, cfg)
+ askFor3rd(uri, cfg, checkThirdParty)
end)
return
end
@@ -597,8 +608,12 @@ local function check3rd(uri)
if ws.isIgnored(uri) then
return
end
- if not config.get(uri, 'Lua.workspace.checkThirdParty') then
+ local checkThirdParty = config.get(uri, 'Lua.workspace.checkThirdParty')
+ -- Backwards compatability: `checkThirdParty` used to be a boolean.
+ if not checkThirdParty or checkThirdParty == 'Disable' then
return
+ elseif checkThirdParty == true then
+ checkThirdParty = 'Ask'
end
local scp = scope.getScope(uri)
if not scp:get 'canCheckThirdParty' then
@@ -608,8 +623,8 @@ local function check3rd(uri)
if not thirdConfigs then
return
end
- check3rdByWords(uri, thirdConfigs)
- check3rdByFileName(uri, thirdConfigs)
+ check3rdByWords(uri, thirdConfigs, checkThirdParty)
+ check3rdByFileName(uri, thirdConfigs, checkThirdParty)
end
local function check3rdOfWorkspace(suri)
diff --git a/script/parser/compile.lua b/script/parser/compile.lua
index e09c958f..5321d9b8 100644
--- a/script/parser/compile.lua
+++ b/script/parser/compile.lua
@@ -470,7 +470,7 @@ end
local function parseLongString()
local start, finish, mark = sfind(Lua, '^(%[%=*%[)', Tokens[Index])
- if not mark then
+ if not start then
return nil
end
fastForwardToken(finish + 1)
@@ -2313,7 +2313,12 @@ local function parseFunction(isLocal, isAction)
local params
if func.name and func.name.type == 'getmethod' then
if func.name.type == 'getmethod' then
- params = {}
+ params = {
+ type = 'funcargs',
+ start = funcRight,
+ finish = funcRight,
+ parent = func
+ }
params[1] = createLocal {
start = funcRight,
finish = funcRight,
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index 0d190ed5..4e71c832 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -56,7 +56,7 @@ local type = type
---@field returnIndex integer
---@field assignIndex integer
---@field docIndex integer
----@field docs parser.object[]
+---@field docs parser.object
---@field state table
---@field comment table
---@field optional boolean
@@ -74,7 +74,9 @@ local type = type
---@field hasBreak? true
---@field hasExit? true
---@field [integer] parser.object|any
----@field package _root parser.object
+---@field package _root parser.object
+---@field package _eachCache? parser.object[]
+---@field package _isGlobal? boolean
---@class guide
---@field debugMode boolean
@@ -154,10 +156,10 @@ local childMap = {
['unary'] = {1},
['doc'] = {'#'},
- ['doc.class'] = {'class', '#extends', '#signs', 'comment'},
+ ['doc.class'] = {'class', '#extends', '#signs', 'docAttr', 'comment'},
['doc.type'] = {'#types', 'name', 'comment'},
['doc.alias'] = {'alias', 'extends', 'comment'},
- ['doc.enum'] = {'enum', 'extends', 'comment'},
+ ['doc.enum'] = {'enum', 'extends', 'comment', 'docAttr'},
['doc.param'] = {'param', 'extends', 'comment'},
['doc.return'] = {'#returns', 'comment'},
['doc.field'] = {'field', 'extends', 'comment'},
@@ -180,6 +182,7 @@ local childMap = {
['doc.cast.block'] = {'extends'},
['doc.operator'] = {'op', 'exp', 'extends'},
['doc.meta'] = {'name'},
+ ['doc.attr'] = {'#names'},
}
---@type table<string, fun(obj: parser.object, list: parser.object[])>
@@ -736,7 +739,7 @@ end
--- 遍历所有指定类型的source
---@param ast parser.object
---@param type string
----@param callback fun(src: parser.object)
+---@param callback fun(src: parser.object): any
---@return any
function m.eachSourceType(ast, type, callback)
local cache = getSourceTypeCache(ast)
diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua
index 9f3b8fd5..8e7d334f 100644
--- a/script/parser/luadoc.lua
+++ b/script/parser/luadoc.lua
@@ -156,6 +156,7 @@ Symbol <- ({} {
---@field calls? parser.object[]
---@field generics? parser.object[]
---@field generic? parser.object
+---@field docAttr? parser.object
local function parseTokens(text, offset)
Ci = 0
@@ -252,6 +253,40 @@ local function nextSymbolOrError(symbol)
return false
end
+local function parseDocAttr(parent)
+ if not checkToken('symbol', '(', 1) then
+ return nil
+ end
+ nextToken()
+
+ local attrs = {
+ type = 'doc.attr',
+ parent = parent,
+ start = getStart(),
+ finish = getStart(),
+ names = {},
+ }
+
+ while true do
+ if checkToken('symbol', ',', 1) then
+ nextToken()
+ goto continue
+ end
+ local name = parseName('doc.attr.name', attrs)
+ if not name then
+ break
+ end
+ attrs.names[#attrs.names+1] = name
+ attrs.finish = name.finish
+ ::continue::
+ end
+
+ nextSymbolOrError(')')
+ attrs.finish = getFinish()
+
+ return attrs
+end
+
local function parseIndexField(parent)
if not checkToken('symbol', '[', 1) then
return nil
@@ -806,6 +841,7 @@ local docSwitch = util.switch()
operators = {},
calls = {},
}
+ result.docAttr = parseDocAttr(result)
result.class = parseName('doc.class.name', result)
if not result.class then
pushWarning {
@@ -1108,13 +1144,13 @@ local docSwitch = util.switch()
end)
: case 'meta'
: call(function ()
- local requireName = parseName('doc.meta.name')
- return {
+ local meta = {
type = 'doc.meta',
- name = requireName,
start = getFinish(),
finish = getFinish(),
}
+ meta.name = parseName('doc.meta.name', meta)
+ return meta
end)
: case 'version'
: call(function ()
@@ -1428,17 +1464,22 @@ local docSwitch = util.switch()
end)
: case 'enum'
: call(function ()
+ local attr = parseDocAttr()
local name = parseName('doc.enum.name')
if not name then
return nil
end
local result = {
- type = 'doc.enum',
- start = name.start,
- finish = name.finish,
- enum = name,
+ type = 'doc.enum',
+ start = name.start,
+ finish = name.finish,
+ enum = name,
+ docAttr = attr,
}
name.parent = result
+ if attr then
+ attr.parent = result
+ end
return result
end)
: case 'private'
@@ -1534,7 +1575,7 @@ local function buildLuaDoc(comment)
parseTokens(doc, startOffset + startPos)
local result, rests = convertTokens(doc)
if result then
- result.range = comment.finish
+ result.range = math.max(comment.finish, result.finish)
local finish = result.firstFinish or result.finish
if rests then
for _, rest in ipairs(rests) do
@@ -1669,7 +1710,9 @@ local function bindDocWithSource(doc, source)
if not source.bindDocs then
source.bindDocs = {}
end
- source.bindDocs[#source.bindDocs+1] = doc
+ if source.bindDocs[#source.bindDocs] ~= doc then
+ source.bindDocs[#source.bindDocs+1] = doc
+ end
doc.bindSource = source
end
@@ -1820,7 +1863,7 @@ local function bindDocsBetween(sources, binded, start, finish)
or src.type == 'setindex'
or src.type == 'setmethod'
or src.type == 'function'
- or src.type == 'table'
+ or src.type == 'return'
or src.type == '...' then
if bindDoc(src, binded) then
ok = true
@@ -1940,7 +1983,7 @@ local bindDocAccept = {
'local' , 'setlocal' , 'setglobal',
'setfield' , 'setmethod' , 'setindex' ,
'tablefield', 'tableindex', 'self' ,
- 'function' , 'table' , '...' ,
+ 'function' , 'return' , '...' ,
}
local function bindDocs(state)
diff --git a/script/plugins/ffi/c-parser/ctypes.lua b/script/plugins/ffi/c-parser/ctypes.lua
index 81d0ccf6..115f78ab 100644
--- a/script/plugins/ffi/c-parser/ctypes.lua
+++ b/script/plugins/ffi/c-parser/ctypes.lua
@@ -149,7 +149,10 @@ end
local function add_to_fields(lst, field_src, fields)
if type(field_src) == "table" and not field_src.ids then
assert(field_src.type.type == "union")
- local subfields = get_fields(lst, field_src.type.fields)
+ local subfields, err = get_fields(lst, field_src.type.fields)
+ if not subfields then
+ return nil, err
+ end
for _, subfield in ipairs(subfields) do
table.insert(fields, subfield)
end
diff --git a/script/plugins/ffi/cdefRerence.lua b/script/plugins/ffi/cdefRerence.lua
index 14643f0f..54a8c2a7 100644
--- a/script/plugins/ffi/cdefRerence.lua
+++ b/script/plugins/ffi/cdefRerence.lua
@@ -20,7 +20,7 @@ end
return function ()
local ffi_state
for uri in files.eachFile() do
- if find(uri, "ffi.lua", 0, true) and find(uri, "lua-language-server", 0, true) then
+ if find(uri, "ffi.lua", 0, true) and find(uri, "meta", 0, true) then
ffi_state = files.getState(uri)
break
end
diff --git a/script/proto/converter.lua b/script/proto/converter.lua
index a723face..e86e4904 100644
--- a/script/proto/converter.lua
+++ b/script/proto/converter.lua
@@ -207,6 +207,14 @@ function m.setOffsetEncoding(encoding)
offsetEncoding = encoding:lower():gsub('%-', '')
end
+---@param s string
+---@param i? integer
+---@param j? integer
+---@return integer
+function m.len(s, i, j)
+ return encoder.len(offsetEncoding, s, i, j)
+end
+
---@class proto.command
---@field title string
---@field command string
diff --git a/script/proto/diagnostic.lua b/script/proto/diagnostic.lua
index 8175a2c5..61b8ff4b 100644
--- a/script/proto/diagnostic.lua
+++ b/script/proto/diagnostic.lua
@@ -62,6 +62,7 @@ m.register {
'missing-return-value',
'redundant-return-value',
'missing-return',
+ 'missing-fields',
} {
group = 'unbalanced',
severity = 'Warning',
@@ -76,6 +77,7 @@ m.register {
'param-type-mismatch',
'cast-type-mismatch',
'return-type-mismatch',
+ 'inject-field',
} {
group = 'type-check',
severity = 'Warning',
diff --git a/script/proto/proto.lua b/script/proto/proto.lua
index 73544ffc..d01c8f36 100644
--- a/script/proto/proto.lua
+++ b/script/proto/proto.lua
@@ -8,6 +8,9 @@ local define = require 'proto.define'
local json = require 'json'
local inspect = require 'inspect'
local thread = require 'bee.thread'
+local fs = require 'bee.filesystem'
+local net = require 'service.net'
+local timer = require 'timer'
local reqCounter = util.counter()
@@ -32,8 +35,7 @@ m.ability = {}
m.waiting = {}
m.holdon = {}
m.mode = 'stdio'
----@type bee.socket.fd
-m.fd = nil
+m.client = nil
function m.getMethodName(proto)
if proto.method:sub(1, 2) == '$/' then
@@ -54,7 +56,8 @@ function m.send(data)
if m.mode == 'stdio' then
io.write(buf)
elseif m.mode == 'socket' then
- m.fd:send(buf)
+ m.client:write(buf)
+ net.update()
end
end
@@ -237,13 +240,37 @@ function m.listen(mode, socketPort)
io.stdout:setvbuf 'no'
pub.task('loadProtoByStdio')
elseif mode == 'socket' then
- local rfd = assert(socket('tcp'))
- rfd:connect('127.0.0.1', socketPort)
- local wfd1, wfd2 = socket.pair()
- m.fd = wfd1
+ local unixFolder = LOGPATH .. '/unix'
+ fs.create_directories(fs.path(unixFolder))
+ local unixPath = unixFolder .. '/' .. tostring(socketPort)
+
+ local server = net.listen('unix', unixPath)
+
+ assert(server)
+
+ local dummyClient = {
+ buf = '',
+ write = function (self, data)
+ self.buf = self.buf.. data
+ end,
+ update = function () end,
+ }
+ m.client = dummyClient
+
+ local t = timer.loop(0.1, function ()
+ net.update()
+ end)
+
+ function server:on_accept(client)
+ t:remove()
+ m.client = client
+ client:write(dummyClient.buf)
+ net.update()
+ end
+
pub.task('loadProtoBySocket', {
- wfd = wfd2:detach(),
- rfd = rfd:detach(),
+ port = socketPort,
+ unixPath = unixPath,
})
end
end
diff --git a/script/provider/formatting.lua b/script/provider/formatting.lua
index 73f9a534..ea94db08 100644
--- a/script/provider/formatting.lua
+++ b/script/provider/formatting.lua
@@ -82,16 +82,13 @@ function m.updateNonStandardSymbols(symbols)
return
end
- local eqTokens = {}
- for _, token in ipairs(symbols) do
- if token:find("=") and token ~= "!=" then
- table.insert(eqTokens, token)
+ for _, symbol in ipairs(symbols) do
+ if symbol == "//" then
+ codeFormat.set_clike_comments_symbol()
end
end
- if #eqTokens ~= 0 then
- codeFormat.set_nonstandard_symbol()
- end
+ codeFormat.set_nonstandard_symbol()
end
config.watch(function(uri, key, value)
diff --git a/script/provider/provider.lua b/script/provider/provider.lua
index 787cfeb8..a791e980 100644
--- a/script/provider/provider.lua
+++ b/script/provider/provider.lua
@@ -712,11 +712,11 @@ m.register 'completionItem/resolve' {
--await.setPriority(1000)
local state = files.getState(uri)
if not state then
- return nil
+ return item
end
local resolved = core.resolve(id)
if not resolved then
- return nil
+ return item
end
item.detail = resolved.detail or item.detail
item.documentation = resolved.description and {
@@ -772,8 +772,8 @@ m.register 'textDocument/signatureHelp' {
for j, param in ipairs(result.params) do
parameters[j] = {
label = {
- param.label[1],
- param.label[2],
+ converter.len(result.label, 1, param.label[1]),
+ converter.len(result.label, 1, param.label[2]),
}
}
end
@@ -904,7 +904,7 @@ m.register 'textDocument/codeLens' {
resolveProvider = true,
}
},
- abortByFileUpdate = true,
+ --abortByFileUpdate = true,
---@async
function (params)
local uri = files.getRealUri(params.textDocument.uri)
diff --git a/script/service/net.lua b/script/service/net.lua
index 61603d79..2019406e 100644
--- a/script/service/net.lua
+++ b/script/service/net.lua
@@ -1,42 +1,39 @@
local socket = require "bee.socket"
+local select = require "bee.select"
+local selector = select.create()
+local SELECT_READ <const> = select.SELECT_READ
+local SELECT_WRITE <const> = select.SELECT_WRITE
-local readfds = {}
-local writefds = {}
-local map = {}
-
-local function FD_SET(set, fd)
- for i = 1, #set do
- if fd == set[i] then
- return
- end
+local function fd_set_read(s)
+ if s._flags & SELECT_READ ~= 0 then
+ return
end
- set[#set+1] = fd
+ s._flags = s._flags | SELECT_READ
+ selector:event_mod(s._fd, s._flags)
end
-local function FD_CLR(set, fd)
- for i = 1, #set do
- if fd == set[i] then
- set[i] = set[#set]
- set[#set] = nil
- return
- end
+local function fd_clr_read(s)
+ if s._flags & SELECT_READ == 0 then
+ return
end
+ s._flags = s._flags & (~SELECT_READ)
+ selector:event_mod(s._fd, s._flags)
end
-local function fd_set_read(fd)
- FD_SET(readfds, fd)
-end
-
-local function fd_clr_read(fd)
- FD_CLR(readfds, fd)
-end
-
-local function fd_set_write(fd)
- FD_SET(writefds, fd)
+local function fd_set_write(s)
+ if s._flags & SELECT_WRITE ~= 0 then
+ return
+ end
+ s._flags = s._flags | SELECT_WRITE
+ selector:event_mod(s._fd, s._flags)
end
-local function fd_clr_write(fd)
- FD_CLR(writefds, fd)
+local function fd_clr_write(s)
+ if s._flags & SELECT_WRITE == 0 then
+ return
+ end
+ s._flags = s._flags & (~SELECT_WRITE)
+ selector:event_mod(s._fd, s._flags)
end
local function on_event(self, name, ...)
@@ -49,8 +46,8 @@ end
local function close(self)
local fd = self._fd
on_event(self, "close")
+ selector:event_del(fd)
fd:close()
- map[fd] = nil
end
local stream_mt = {}
@@ -69,7 +66,7 @@ function stream:write(data)
return
end
if self._writebuf == "" then
- fd_set_write(self._fd)
+ fd_set_write(self)
end
self._writebuf = self._writebuf .. data
end
@@ -79,35 +76,17 @@ end
function stream:close()
if not self.shutdown_r then
self.shutdown_r = true
- fd_clr_read(self._fd)
+ fd_clr_read(self)
end
if self.shutdown_w or self._writebuf == "" then
self.shutdown_w = true
- fd_clr_write(self._fd)
+ fd_clr_write(self)
close(self)
end
end
-function stream:update(timeout)
- local fd = self._fd
- local r = {fd}
- local w = r
- if self._writebuf == "" then
- w = nil
- end
- local rd, wr = socket.select(r, w, timeout or 0)
- if rd then
- if #rd > 0 then
- self:select_r()
- end
- if #wr > 0 then
- self:select_w()
- end
- end
-end
local function close_write(self)
- fd_clr_write(self._fd)
+ fd_clr_write(self)
if self.shutdown_r then
- fd_clr_read(self._fd)
close(self)
end
end
@@ -125,6 +104,7 @@ function stream:select_w()
if n == nil then
self.shutdown_w = true
close_write(self)
+ elseif n == false then
else
self._writebuf = self._writebuf:sub(n + 1)
if self._writebuf == "" then
@@ -132,26 +112,43 @@ function stream:select_w()
end
end
end
+local function update_stream(s, event)
+ if event & SELECT_READ ~= 0 then
+ s:select_r()
+ end
+ if event & SELECT_WRITE ~= 0 then
+ s:select_w()
+ end
+end
local function accept_stream(fd)
- local self = setmetatable({
+ local s = setmetatable({
_fd = fd,
+ _flags = SELECT_READ,
_event = {},
_writebuf = "",
shutdown_r = false,
shutdown_w = false,
}, stream_mt)
- map[fd] = self
- fd_set_read(fd)
- return self
-end
-local function connect_stream(self)
- setmetatable(self, stream_mt)
- fd_set_read(self._fd)
- if self._writebuf ~= "" then
- self:select_w()
+ selector:event_add(fd, SELECT_READ, function (event)
+ update_stream(s, event)
+ end)
+ return s
+end
+local function connect_stream(s)
+ setmetatable(s, stream_mt)
+ selector:event_del(s._fd)
+ if s._writebuf ~= "" then
+ s._flags = SELECT_READ | SELECT_WRITE
+ selector:event_add(s._fd, SELECT_READ | SELECT_WRITE, function (event)
+ update_stream(s, event)
+ end)
+ s:select_w()
else
- fd_clr_write(self._fd)
+ s._flags = SELECT_READ
+ selector:event_add(s._fd, SELECT_READ, function (event)
+ update_stream(s, event)
+ end)
end
end
@@ -169,35 +166,32 @@ function listen:is_closed()
end
function listen:close()
self.shutdown_r = true
- fd_clr_read(self._fd)
close(self)
end
-function listen:update(timeout)
- local fd = self._fd
- local r = {fd}
- local rd = socket.select(r, nil, timeout or 0)
- if rd then
- if #rd > 0 then
- self:select_r()
- end
- end
-end
-function listen:select_r()
- local newfd = self._fd:accept()
- if newfd:status() then
- local news = accept_stream(newfd)
- on_event(self, "accept", news)
- end
-end
local function new_listen(fd)
local s = {
_fd = fd,
+ _flags = SELECT_READ,
_event = {},
shutdown_r = false,
shutdown_w = true,
}
- map[fd] = s
- fd_set_read(fd)
+ selector:event_add(fd, SELECT_READ, function ()
+ local newfd, err = fd:accept()
+ if not newfd then
+ on_event(s, "error", err)
+ return
+ end
+ local ok, err = newfd:status()
+ if not ok then
+ on_event(s, "error", err)
+ return
+ end
+ if newfd:status() then
+ local news = accept_stream(newfd)
+ on_event(s, "accept", news)
+ end
+ end)
return setmetatable(s, listen_mt)
end
@@ -220,39 +214,27 @@ function connect:is_closed()
end
function connect:close()
self.shutdown_w = true
- fd_clr_write(self._fd)
close(self)
end
-function connect:update(timeout)
- local fd = self._fd
- local w = {fd}
- local rd, wr = socket.select(nil, w, timeout or 0)
- if rd then
- if #wr > 0 then
- self:select_w()
- end
- end
-end
-function connect:select_w()
- local ok, err = self._fd:status()
- if ok then
- connect_stream(self)
- on_event(self, "connect")
- else
- on_event(self, "error", err)
- self:close()
- end
-end
local function new_connect(fd)
local s = {
_fd = fd,
+ _flags = SELECT_WRITE,
_event = {},
_writebuf = "",
shutdown_r = false,
shutdown_w = false,
}
- map[fd] = s
- fd_set_write(fd)
+ selector:event_add(fd, SELECT_WRITE, function ()
+ local ok, err = fd:status()
+ if ok then
+ connect_stream(s)
+ on_event(s, "connect")
+ else
+ on_event(s, "error", err)
+ s:close()
+ end
+ end)
return setmetatable(s, connect_mt)
end
@@ -292,18 +274,8 @@ function m.connect(protocol, ...)
end
function m.update(timeout)
- local rd, wr = socket.select(readfds, writefds, timeout or 0)
- if rd then
- for i = 1, #rd do
- local fd = rd[i]
- local s = map[fd]
- s:select_r()
- end
- for i = 1, #wr do
- local fd = wr[i]
- local s = map[fd]
- s:select_w()
- end
+ for func, event in selector:wait(timeout or 0) do
+ func(event)
end
end
diff --git a/script/service/service.lua b/script/service/service.lua
index 7011ec4f..b6056390 100644
--- a/script/service/service.lua
+++ b/script/service/service.lua
@@ -257,7 +257,6 @@ function m.lockCache()
if err then
log.error(err)
end
- pub.task('removeCaches', cacheDir)
end
function m.start()
diff --git a/script/utility.lua b/script/utility.lua
index be945791..936726f9 100644
--- a/script/utility.lua
+++ b/script/utility.lua
@@ -190,38 +190,48 @@ function m.dump(tbl, option)
end
--- 递归判断A与B是否相等
----@param a any
----@param b any
+---@param valueA any
+---@param valueB any
---@return boolean
-function m.equal(a, b)
- local tp1 = type(a)
- local tp2 = type(b)
- if tp1 ~= tp2 then
- return false
- end
- if tp1 == 'table' then
- local mark = {}
- for k, v in pairs(a) do
- mark[k] = true
- local res = m.equal(v, b[k])
- if not res then
- return false
- end
+function m.equal(valueA, valueB)
+ local hasChecked = {}
+
+ local function equal(a, b)
+ local tp1 = type(a)
+ local tp2 = type(b)
+ if tp1 ~= tp2 then
+ return false
end
- for k in pairs(b) do
- if not mark[k] then
- return false
+ if tp1 == 'table' then
+ if hasChecked[a] then
+ return true
+ end
+ hasChecked[a] = true
+ local mark = {}
+ for k, v in pairs(a) do
+ mark[k] = true
+ local res = equal(v, b[k])
+ if not res then
+ return false
+ end
+ end
+ for k in pairs(b) do
+ if not mark[k] then
+ return false
+ end
end
- end
- return true
- elseif tp1 == 'number' then
- if mathAbs(a - b) <= 1e-10 then
return true
+ elseif tp1 == 'number' then
+ if mathAbs(a - b) <= 1e-10 then
+ return true
+ end
+ return tostring(a) == tostring(b)
+ else
+ return a == b
end
- return tostring(a) == tostring(b)
- else
- return a == b
end
+
+ return equal(valueA, valueB)
end
local function sortTable(tbl)
@@ -283,6 +293,9 @@ end
--- 读取文件
---@param path string
+---@param keepBom? boolean
+---@return string? text
+---@return string? errMsg
function m.loadFile(path, keepBom)
local f, e = ioOpen(path, 'rb')
if not f then
@@ -308,6 +321,8 @@ end
--- 写入文件
---@param path string
---@param content string
+---@return boolean ok
+---@return string? errMsg
function m.saveFile(path, content)
local f, e = ioOpen(path, "wb")
@@ -357,6 +372,7 @@ end
--- 深拷贝(不处理元表)
---@param source table
---@param target? table
+---@return table
function m.deepCopy(source, target)
local mark = {}
local function copy(a, b)
@@ -379,6 +395,8 @@ function m.deepCopy(source, target)
end
--- 序列化
+---@param t table
+---@return table
function m.unpack(t)
local result = {}
local tid = 0
@@ -406,6 +424,8 @@ function m.unpack(t)
end
--- 反序列化
+---@param t table
+---@return table
function m.pack(t)
local cache = {}
local function pack(id)
@@ -518,18 +538,25 @@ function m.utf8Len(str, start, finish)
return len
end
-function m.revertTable(t)
- local len = #t
+-- 把数组中的元素顺序*原地*反转
+---@param arr any[]
+---@return any[]
+function m.revertArray(arr)
+ local len = #arr
if len <= 1 then
- return t
+ return arr
end
for x = 1, len // 2 do
local y = len - x + 1
- t[x], t[y] = t[y], t[x]
+ arr[x], arr[y] = arr[y], arr[x]
end
- return t
+ return arr
end
+-- 创建一个value-key表
+---@generic K, V
+---@param t table<K, V>
+---@return table<V, K>
function m.revertMap(t)
local nt = {}
for k, v in pairs(t) do
@@ -624,6 +651,11 @@ function m.eachLine(text, keepNL)
end
end
+---@alias SortByScoreCallback fun(o: any): integer
+
+-- 按照分数排序,分数越高越靠前
+---@param tbl any[]
+---@param callbacks SortByScoreCallback | SortByScoreCallback[]
function m.sortByScore(tbl, callbacks)
if type(callbacks) ~= 'table' then
callbacks = { callbacks }
@@ -651,6 +683,16 @@ function m.sortByScore(tbl, callbacks)
end)
end
+---@param arr any[]
+---@return SortByScoreCallback
+function m.sortCallbackOfIndex(arr)
+ ---@type table<any, integer>
+ local indexMap = m.revertMap(arr)
+ return function (v)
+ return - indexMap[v]
+ end
+end
+
---裁剪字符串
---@param str string
---@param mode? '"left"'|'"right"'
@@ -733,6 +775,7 @@ function switchMT:has(name)
end
---@param name string
+---@param ... any
---@return ...
function switchMT:__call(name, ...)
local callback = self.map[name] or self._default
@@ -752,6 +795,8 @@ function m.switch()
end
---@param f async fun()
+---@param name string
+---@return any, boolean
function m.getUpvalue(f, name)
for i = 1, 999 do
local uname, value = getupvalue(f, i)
@@ -819,6 +864,7 @@ end
---@param t table
---@param sorter boolean|function
+---@return any[]
function m.getTableKeys(t, sorter)
local keys = {}
for k in pairs(t) do
@@ -841,6 +887,15 @@ function m.arrayHas(array, value)
return false
end
+function m.arrayIndexOf(array, value)
+ for i = 1, #array do
+ if array[i] == value then
+ return i
+ end
+ end
+ return nil
+end
+
function m.arrayInsert(array, value)
if not m.arrayHas(array, value) then
array[#array+1] = value
@@ -873,4 +928,24 @@ function m.cacheReturn(func)
end
end
+---@param a table
+---@param b table
+---@return table
+function m.tableMerge(a, b)
+ for k, v in pairs(b) do
+ a[k] = v
+ end
+ return a
+end
+
+---@param a any[]
+---@param b any[]
+---@return any[]
+function m.arrayMerge(a, b)
+ for i = 1, #b do
+ a[#a+1] = b[i]
+ end
+ return a
+end
+
return m
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 6b4636fc..8a1fa96a 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -7,10 +7,14 @@ local files = require 'files'
local vm = require 'vm.vm'
---@class parser.object
----@field _compiledNodes boolean
----@field _node vm.node
----@field cindex integer
----@field func parser.object
+---@field _compiledNodes boolean
+---@field _node vm.node
+---@field cindex integer
+---@field func parser.object
+---@field hideView boolean
+---@field package _returns? parser.object[]
+---@field package _callReturns? parser.object[]
+---@field package _asCache? parser.object[]
-- 该函数有副作用,会给source绑定node!
---@param source parser.object
@@ -28,6 +32,12 @@ function vm.bindDocs(source)
end
if doc.type == 'doc.class' then
vm.setNode(source, vm.compileNode(doc))
+ for j = i + 1, #docs do
+ local overload = docs[j]
+ if overload.type == 'doc.overload' then
+ overload.overload.hideView = true
+ end
+ end
return true
end
if doc.type == 'doc.param' then
@@ -55,6 +65,9 @@ function vm.bindDocs(source)
vm.setNode(source, vm.compileNode(ast))
return true
end
+ if doc.type == 'doc.overload' then
+ vm.setNode(source, vm.compileNode(doc))
+ end
end
return false
end
@@ -473,6 +486,7 @@ function vm.getReturnOfFunction(func, index)
func._returns = {}
end
if not func._returns[index] then
+ ---@diagnostic disable-next-line: missing-fields
func._returns[index] = {
type = 'function.return',
parent = func,
@@ -570,6 +584,7 @@ local function getReturn(func, index, args)
end
if not func._callReturns[index] then
local call = func.parent
+ ---@diagnostic disable-next-line: missing-fields
func._callReturns[index] = {
type = 'call.return',
parent = call,
@@ -756,6 +771,11 @@ function vm.selectNode(list, index)
if not exp then
return vm.createNode(vm.declareGlobal('type', 'nil')), nil
end
+
+ if vm.bindDocs(list) then
+ return vm.compileNode(list), exp
+ end
+
---@type vm.node?
local result
if exp.type == 'call' then
@@ -862,52 +882,69 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
end
end
- for n in callNode:eachObject() do
- if n.type == 'function' then
- ---@cast n parser.object
- local sign = vm.getSign(n)
+ ---@param n parser.object
+ local function dealDocFunc(n)
+ local myEvent
+ if n.args[eventIndex] then
+ local argNode = vm.compileNode(n.args[eventIndex])
+ myEvent = argNode:get(1)
+ end
+ if not myEvent
+ or not eventMap
+ or myIndex <= eventIndex
+ or myEvent.type ~= 'doc.type.string'
+ or eventMap[myEvent[1]] then
local farg = getFuncArg(n, myIndex)
if farg then
for fn in vm.compileNode(farg):eachObject() do
if isValidCallArgNode(arg, fn) then
- if fn.type == 'doc.type.function' then
- ---@cast fn parser.object
- if sign then
- local generic = vm.createGeneric(fn, sign)
- local args = {}
- for i = fixIndex + 1, myIndex - 1 do
- args[#args+1] = call.args[i]
- end
- local resolvedNode = generic:resolve(guide.getUri(call), args)
- vm.setNode(arg, resolvedNode)
- goto CONTINUE
- end
- end
vm.setNode(arg, fn)
- ::CONTINUE::
end
end
end
end
- if n.type == 'doc.type.function' then
- ---@cast n parser.object
- local myEvent
- if n.args[eventIndex] then
- local argNode = vm.compileNode(n.args[eventIndex])
- myEvent = argNode:get(1)
- end
- if not myEvent
- or not eventMap
- or myIndex <= eventIndex
- or myEvent.type ~= 'doc.type.string'
- or eventMap[myEvent[1]] then
- local farg = getFuncArg(n, myIndex)
- if farg then
- for fn in vm.compileNode(farg):eachObject() do
- if isValidCallArgNode(arg, fn) then
- vm.setNode(arg, fn)
+ end
+
+ ---@param n parser.object
+ local function dealFunction(n)
+ local sign = vm.getSign(n)
+ local farg = getFuncArg(n, myIndex)
+ if farg then
+ for fn in vm.compileNode(farg):eachObject() do
+ if isValidCallArgNode(arg, fn) then
+ if fn.type == 'doc.type.function' then
+ ---@cast fn parser.object
+ if sign then
+ local generic = vm.createGeneric(fn, sign)
+ local args = {}
+ for i = fixIndex + 1, myIndex - 1 do
+ args[#args+1] = call.args[i]
+ end
+ local resolvedNode = generic:resolve(guide.getUri(call), args)
+ vm.setNode(arg, resolvedNode)
+ goto CONTINUE
end
end
+ vm.setNode(arg, fn)
+ ::CONTINUE::
+ end
+ end
+ end
+ end
+
+ for n in callNode:eachObject() do
+ if n.type == 'function' then
+ ---@cast n parser.object
+ dealFunction(n)
+ elseif n.type == 'doc.type.function' then
+ ---@cast n parser.object
+ dealDocFunc(n)
+ elseif n.type == 'global' and n.cate == 'type' then
+ ---@cast n vm.global
+ local overloads = vm.getOverloadsByTypeName(n.name, guide.getUri(arg))
+ if overloads then
+ for _, func in ipairs(overloads) do
+ dealDocFunc(func)
end
end
end
@@ -1020,6 +1057,7 @@ local function compileLocal(source)
vm.setNode(source, vm.compileNode(source.value))
end
end
+
-- function x.y(self, ...) --> function x:y(...)
if source[1] == 'self'
and not hasMarkDoc
@@ -1031,6 +1069,7 @@ local function compileLocal(source)
vm.setNode(source, vm.compileNode(setfield.node))
end
end
+
if source.parent.type == 'funcargs' and not hasMarkDoc and not hasMarkParam then
local func = source.parent.parent
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
@@ -1055,6 +1094,7 @@ local function compileLocal(source)
vm.setNode(source, vm.declareGlobal('type', 'any'))
end
end
+
-- for x in ... do
if source.parent.type == 'in' then
compileForVars(source.parent, source)
@@ -1094,6 +1134,12 @@ local function compileLocal(source)
end
end
+ if source.value
+ and source.value.type == 'nil'
+ and not myNode:hasKnownType() then
+ vm.setNode(source, vm.compileNode(source.value))
+ end
+
myNode.hasDefined = hasMarkDoc or hasMarkParam or hasMarkValue
end
@@ -1140,6 +1186,9 @@ local compilerSwitch = util.switch()
end)
: case 'table'
: call(function (source)
+ if vm.bindAs(source) then
+ return
+ end
vm.setNode(source, source)
if source.parent.type == 'callargs' then
@@ -1147,6 +1196,16 @@ local compilerSwitch = util.switch()
vm.compileCallArg(source, call)
end
+ if source.parent.type == 'return' then
+ local myIndex = util.arrayIndexOf(source.parent, source)
+ ---@cast myIndex -?
+ local parentNode = vm.selectNode(source.parent, myIndex)
+ if not parentNode:isEmpty() then
+ vm.setNode(source, parentNode)
+ return
+ end
+ end
+
if source.parent.type == 'setglobal'
or source.parent.type == 'local'
or source.parent.type == 'setlocal'
@@ -1648,7 +1707,7 @@ local compilerSwitch = util.switch()
if state.type == 'doc.return'
or state.type == 'doc.param' then
local func = state.bindSource
- if func.type == 'function' then
+ if func and func.type == 'function' then
local node = guide.getFunctionSelfNode(func)
if node then
vm.setNode(source, vm.compileNode(node))
diff --git a/script/vm/doc.lua b/script/vm/doc.lua
index a6ea248f..6ac39910 100644
--- a/script/vm/doc.lua
+++ b/script/vm/doc.lua
@@ -5,7 +5,11 @@ local vm = require 'vm.vm'
local config = require 'config'
---@class parser.object
----@field package _castTargetHead parser.object | vm.global | false
+---@field package _castTargetHead? parser.object | vm.global | false
+---@field package _validVersions? table<string, boolean>
+---@field package _deprecated? parser.object | false
+---@field package _async? boolean
+---@field package _nodiscard? boolean
---获取class与alias
---@param suri uri
@@ -467,3 +471,18 @@ function vm.getCastTargetHead(doc)
end
return nil
end
+
+---@param doc parser.object
+---@param key string
+---@return boolean
+function vm.docHasAttr(doc, key)
+ if not doc.docAttr then
+ return false
+ end
+ for _, name in ipairs(doc.docAttr.names) do
+ if name[1] == key then
+ return true
+ end
+ end
+ return false
+end
diff --git a/script/vm/function.lua b/script/vm/function.lua
index bdd7e229..c6df6349 100644
--- a/script/vm/function.lua
+++ b/script/vm/function.lua
@@ -372,8 +372,14 @@ function vm.isVarargFunctionWithOverloads(func)
if not func.args then
return false
end
- if not func.args[1] or func.args[1].type ~= '...' then
- return false
+ if func.args[1] and func.args[1].type == 'self' then
+ if not func.args[2] or func.args[2].type ~= '...' then
+ return false
+ end
+ else
+ if not func.args[1] or func.args[1].type ~= '...' then
+ return false
+ end
end
if not func.bindDocs then
return false
diff --git a/script/vm/global.lua b/script/vm/global.lua
index c1b5f320..e830f6d8 100644
--- a/script/vm/global.lua
+++ b/script/vm/global.lua
@@ -1,6 +1,7 @@
local util = require 'utility'
local scope = require 'workspace.scope'
local guide = require 'parser.guide'
+local config = require 'config'
---@class vm
local vm = require 'vm.vm'
@@ -371,16 +372,36 @@ local compilerGlobalSwitch = util.switch()
return
end
source._enums = {}
- for _, field in ipairs(tbl) do
- if field.type == 'tablefield' then
- source._enums[#source._enums+1] = field
- local subType = vm.declareGlobal('type', name .. '.' .. field.field[1], uri)
- subType:addSet(uri, field)
- elseif field.type == 'tableindex' then
- source._enums[#source._enums+1] = field
- if field.index.type == 'string' then
- local subType = vm.declareGlobal('type', name .. '.' .. field.index[1], uri)
+ if vm.docHasAttr(source, 'key') then
+ for _, field in ipairs(tbl) do
+ if field.type == 'tablefield' then
+ source._enums[#source._enums+1] = {
+ type = 'doc.type.string',
+ start = field.field.start,
+ finish = field.field.finish,
+ [1] = field.field[1],
+ }
+ elseif field.type == 'tableindex' then
+ source._enums[#source._enums+1] = {
+ type = 'doc.type.string',
+ start = field.index.start,
+ finish = field.index.finish,
+ [1] = field.index[1],
+ }
+ end
+ end
+ else
+ for _, field in ipairs(tbl) do
+ if field.type == 'tablefield' then
+ source._enums[#source._enums+1] = field
+ local subType = vm.declareGlobal('type', name .. '.' .. field.field[1], uri)
subType:addSet(uri, field)
+ elseif field.type == 'tableindex' then
+ source._enums[#source._enums+1] = field
+ if field.index.type == 'string' then
+ local subType = vm.declareGlobal('type', name .. '.' .. field.index[1], uri)
+ subType:addSet(uri, field)
+ end
end
end
end
@@ -518,6 +539,33 @@ function vm.hasGlobalSets(suri, cate, name)
return true
end
+---@param src parser.object
+local function checkIsUndefinedGlobal(src)
+ local key = src[1]
+
+ local uri = guide.getUri(src)
+ local dglobals = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals'))
+ local rspecial = config.get(uri, 'Lua.runtime.special')
+
+ local node = src.node
+ return src.type == 'getglobal' and key and not (
+ dglobals[key] or
+ rspecial[key] or
+ node.tag ~= '_ENV' or
+ vm.hasGlobalSets(uri, 'variable', key)
+ )
+end
+
+---@param src parser.object
+---@return boolean
+function vm.isUndefinedGlobal(src)
+ local node = vm.compileNode(src)
+ if node.undefinedGlobal == nil then
+ node.undefinedGlobal = checkIsUndefinedGlobal(src)
+ end
+ return node.undefinedGlobal
+end
+
---@param source parser.object
function compileObject(source)
if source._globalNode ~= nil then
@@ -593,6 +641,7 @@ function vm.getGlobalBase(source)
end
local name = global:asKeyName()
if not root._globalBaseMap[name] then
+ ---@diagnostic disable-next-line: missing-fields
root._globalBaseMap[name] = {
type = 'globalbase',
parent = root,
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index 94fdfd88..f2673ed3 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -386,9 +386,11 @@ function mt:_computeViews(uri)
self.views = {}
for n in self.node:eachObject() do
- local view = viewNodeSwitch(n.type, n, self, uri)
- if view then
- self.views[view] = true
+ if not n.hideView then
+ local view = viewNodeSwitch(n.type, n, self, uri)
+ if view then
+ self.views[view] = true
+ end
end
end
@@ -565,11 +567,12 @@ function vm.viewKey(source, uri)
return vm.viewKey(source.types[1], uri)
else
local key = vm.getInfer(source):view(uri)
- return '[' .. key .. ']'
+ return '[' .. key .. ']', key
end
end
if source.type == 'tableindex'
- or source.type == 'setindex' then
+ or source.type == 'setindex'
+ or source.type == 'getindex' then
local index = source.index
local name = vm.getInfer(index):viewLiterals()
if not name then
@@ -587,7 +590,11 @@ function vm.viewKey(source, uri)
return vm.viewKey(source.name, uri)
end
if source.type == 'doc.type.name' then
- return '[' .. source[1] .. ']'
+ return '[' .. source[1] .. ']', source[1]
+ end
+ if source.type == 'doc.type.string' then
+ local name = util.viewString(source[1], source[2])
+ return ('[%s]'):format(name), name
end
local key = vm.getKeyName(source)
if key == nil then
diff --git a/script/vm/node.lua b/script/vm/node.lua
index 0ffd8c70..bc1dfcb1 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -15,6 +15,7 @@ vm.nodeCache = setmetatable({}, util.MODE_K)
---@field [integer] vm.node.object
---@field [vm.node.object] true
---@field fields? table<vm.node|string, vm.node>
+---@field undefinedGlobal boolean?
local mt = {}
mt.__index = mt
mt.id = 0
diff --git a/script/vm/operator.lua b/script/vm/operator.lua
index bc8703c6..7ce2b30d 100644
--- a/script/vm/operator.lua
+++ b/script/vm/operator.lua
@@ -126,6 +126,7 @@ vm.unarySwich = util.switch()
if result == nil then
vm.setNode(source, vm.declareGlobal('type', 'boolean'))
else
+ ---@diagnostic disable-next-line: missing-fields
vm.setNode(source, {
type = 'boolean',
start = source.start,
@@ -155,6 +156,7 @@ vm.unarySwich = util.switch()
vm.setNode(source, node or vm.declareGlobal('type', 'number'))
end
else
+ ---@diagnostic disable-next-line: missing-fields
vm.setNode(source, {
type = 'number',
start = source.start,
@@ -171,6 +173,7 @@ vm.unarySwich = util.switch()
local node = vm.runOperator('bnot', source[1])
vm.setNode(source, node or vm.declareGlobal('type', 'integer'))
else
+ ---@diagnostic disable-next-line: missing-fields
vm.setNode(source, {
type = 'integer',
start = source.start,
@@ -223,6 +226,7 @@ vm.binarySwitch = util.switch()
if source.op.type == '~=' then
result = not result
end
+ ---@diagnostic disable-next-line: missing-fields
vm.setNode(source, {
type = 'boolean',
start = source.start,
@@ -247,6 +251,7 @@ vm.binarySwitch = util.switch()
or op == '&' and a & b
or op == '|' and a | b
or op == '~' and a ~ b
+ ---@diagnostic disable-next-line: missing-fields
vm.setNode(source, {
type = 'integer',
start = source.start,
@@ -285,6 +290,7 @@ vm.binarySwitch = util.switch()
or op == '%' and a % b
or op == '//' and a // b
or op == '^' and a ^ b
+ ---@diagnostic disable-next-line: missing-fields
vm.setNode(source, {
type = (op == '//' or math.type(result) == 'integer') and 'integer' or 'number',
start = source.start,
@@ -364,6 +370,7 @@ vm.binarySwitch = util.switch()
end
end
end
+ ---@diagnostic disable-next-line: missing-fields
vm.setNode(source, {
type = 'string',
start = source.start,
@@ -407,6 +414,7 @@ vm.binarySwitch = util.switch()
or op == '<' and a < b
or op == '>=' and a >= b
or op == '<=' and a <= b
+ ---@diagnostic disable-next-line: missing-fields
vm.setNode(source, {
type = 'boolean',
start = source.start,
diff --git a/script/vm/sign.lua b/script/vm/sign.lua
index 1f434475..38cb2242 100644
--- a/script/vm/sign.lua
+++ b/script/vm/sign.lua
@@ -142,13 +142,15 @@ function mt:resolve(uri, args)
end
if object.type == 'doc.type.function' then
for i, arg in ipairs(object.args) do
- for n in node:eachObject() do
- if n.type == 'function'
- or n.type == 'doc.type.function' then
- ---@cast n parser.object
- local farg = n.args and n.args[i]
- if farg then
- resolve(arg.extends, vm.compileNode(farg))
+ if arg.extends then
+ for n in node:eachObject() do
+ if n.type == 'function'
+ or n.type == 'doc.type.function' then
+ ---@cast n parser.object
+ local farg = n.args and n.args[i]
+ if farg then
+ resolve(arg.extends, vm.compileNode(farg))
+ end
end
end
end
@@ -254,7 +256,7 @@ function mt:resolve(uri, args)
local argNode = vm.compileNode(arg)
local knownTypes, genericNames = getSignInfo(sign)
if not isAllResolved(genericNames) then
- local newArgNode = buildArgNode(argNode,sign, knownTypes)
+ local newArgNode = buildArgNode(argNode, sign, knownTypes)
resolve(sign, newArgNode)
end
end
diff --git a/script/vm/type.lua b/script/vm/type.lua
index 8382eb86..545d2de5 100644
--- a/script/vm/type.lua
+++ b/script/vm/type.lua
@@ -143,6 +143,9 @@ end
---@param errs? typecheck.err[]
---@return boolean?
local function checkChildEnum(childName, parent , uri, mark, errs)
+ if mark[childName] then
+ return
+ end
local childClass = vm.getGlobal('type', childName)
if not childClass then
return nil
@@ -157,11 +160,14 @@ local function checkChildEnum(childName, parent , uri, mark, errs)
if not enums then
return nil
end
+ mark[childName] = true
for _, enum in ipairs(enums) do
if not vm.isSubType(uri, vm.compileNode(enum), parent, mark ,errs) then
+ mark[childName] = nil
return false
end
end
+ mark[childName] = nil
return true
end
@@ -752,7 +758,7 @@ function vm.viewTypeErrorMessage(uri, errs)
lines[#lines+1] = '- ' .. line
end
end
- util.revertTable(lines)
+ util.revertArray(lines)
if #lines > 15 then
lines[13] = ('...(+%d)'):format(#lines - 15)
table.move(lines, #lines - 2, #lines, 14)
@@ -761,3 +767,25 @@ function vm.viewTypeErrorMessage(uri, errs)
return table.concat(lines, '\n')
end
end
+
+---@param name string
+---@param uri uri
+---@return parser.object[]?
+function vm.getOverloadsByTypeName(name, uri)
+ local global = vm.getGlobal('type', name)
+ if not global then
+ return nil
+ end
+ local results
+ for _, set in ipairs(global:getSets(uri)) do
+ for _, doc in ipairs(set.bindGroup) do
+ if doc.type == 'doc.overload' then
+ if not results then
+ results = {}
+ end
+ results[#results+1] = doc.overload
+ end
+ end
+ end
+ return results
+end
diff --git a/script/vm/visible.lua b/script/vm/visible.lua
index e550280f..d13ecf1f 100644
--- a/script/vm/visible.lua
+++ b/script/vm/visible.lua
@@ -7,9 +7,10 @@ local glob = require 'glob'
---@class parser.object
---@field package _visibleType? parser.visibleType
----@param source parser.object
----@return parser.visibleType
-function vm.getVisibleType(source)
+local function getVisibleType(source)
+ if guide.isLiteral(source) then
+ return 'public'
+ end
if source._visibleType then
return source._visibleType
end
@@ -55,6 +56,27 @@ function vm.getVisibleType(source)
return 'public'
end
+---@class vm.node
+---@field package _visibleType parser.visibleType
+
+---@param source parser.object
+---@return parser.visibleType
+function vm.getVisibleType(source)
+ local node = vm.compileNode(source)
+ if node._visibleType then
+ return node._visibleType
+ end
+ for _, def in ipairs(vm.getDefs(source)) do
+ local visible = getVisibleType(def)
+ if visible ~= 'public' then
+ node._visibleType = visible
+ return visible
+ end
+ end
+ node._visibleType = 'public'
+ return 'public'
+end
+
---@param source parser.object
---@return vm.global?
function vm.getParentClass(source)
diff --git a/script/workspace/workspace.lua b/script/workspace/workspace.lua
index 3e85e0fc..97518e84 100644
--- a/script/workspace/workspace.lua
+++ b/script/workspace/workspace.lua
@@ -50,8 +50,10 @@ function m.create(uri)
m.folders[#m.folders+1] = scp
if uri == furi.encode '/'
or uri == furi.encode(os.getenv 'HOME' or '') then
- client.showMessage('Error', lang.script('WORKSPACE_NOT_ALLOWED', furi.decode(uri)))
- scp:set('bad root', true)
+ if not FORCE_ACCEPT_WORKSPACE then
+ client.showMessage('Error', lang.script('WORKSPACE_NOT_ALLOWED', furi.decode(uri)))
+ scp:set('bad root', true)
+ end
end
end
@@ -469,10 +471,6 @@ function m.flushFiles(scp)
for uri in pairs(cachedUris) do
files.delRef(uri)
end
- collectgarbage()
- collectgarbage()
- -- TODO: wait maillist
- collectgarbage 'restart'
end
---@param scp scope
@@ -493,6 +491,8 @@ end
---@async
---@param scp scope
function m.awaitReload(scp)
+ await.unique('workspace reload:' .. scp:getName())
+ await.sleep(0.1)
scp:set('ready', false)
scp:set('nativeMatcher', nil)
scp:set('libraryMatcher', nil)