summaryrefslogtreecommitdiff
path: root/script
diff options
context:
space:
mode:
Diffstat (limited to 'script')
-rw-r--r--script/client.lua3
-rw-r--r--script/config/loader.lua6
-rw-r--r--script/core/completion/completion.lua88
-rw-r--r--script/core/definition.lua4
-rw-r--r--script/core/diagnostics/close-non-object.lua13
-rw-r--r--script/core/diagnostics/duplicate-doc-field.lua1
-rw-r--r--script/core/diagnostics/global-in-nil-env.lua2
-rw-r--r--script/core/diagnostics/init.lua2
-rw-r--r--script/core/diagnostics/lowercase-global.lua25
-rw-r--r--script/core/diagnostics/missing-parameter.lua73
-rw-r--r--script/core/diagnostics/need-check-nil.lua39
-rw-r--r--script/core/diagnostics/no-unknown.lua4
-rw-r--r--script/core/diagnostics/not-yieldable.lua5
-rw-r--r--script/core/diagnostics/redundant-parameter.lua48
-rw-r--r--script/core/diagnostics/undefined-field.lua3
-rw-r--r--script/core/diagnostics/undefined-global.lua3
-rw-r--r--script/core/diagnostics/unused-function.lua152
-rw-r--r--script/core/diagnostics/unused-vararg.lua5
-rw-r--r--script/core/formatting.lua2
-rw-r--r--script/core/hint.lua3
-rw-r--r--script/core/hover/args.lua33
-rw-r--r--script/core/hover/description.lua5
-rw-r--r--script/core/hover/init.lua24
-rw-r--r--script/core/hover/label.lua9
-rw-r--r--script/core/hover/name.lua4
-rw-r--r--script/core/hover/return.lua12
-rw-r--r--script/core/hover/table.lua116
-rw-r--r--script/core/look-backward.lua7
-rw-r--r--script/core/matchkey.lua2
-rw-r--r--script/core/rangeformatting.lua2
-rw-r--r--script/core/rename.lua5
-rw-r--r--script/core/semantic-tokens.lua105
-rw-r--r--script/core/signature.lua3
-rw-r--r--script/core/type-definition.lua1
-rw-r--r--script/doctor.lua3
-rw-r--r--script/encoder/init.lua12
-rw-r--r--script/files.lua4
-rw-r--r--script/fs-utility.lua13
-rw-r--r--script/glob/gitignore.lua2
-rw-r--r--script/jsonc.lua603
-rw-r--r--script/language.lua4
-rw-r--r--script/log.lua5
-rw-r--r--script/parser/guide.lua52
-rw-r--r--script/parser/luadoc.lua1023
-rw-r--r--script/parser/newparser.lua50
-rw-r--r--script/proto/define.lua18
-rw-r--r--script/provider/diagnostic.lua18
-rw-r--r--script/provider/provider.lua23
-rw-r--r--script/pub/pub.lua2
-rw-r--r--script/service/telemetry.lua2
-rw-r--r--script/utility.lua12
-rw-r--r--script/vm/compiler.lua459
-rw-r--r--script/vm/def.lua15
-rw-r--r--script/vm/doc.lua11
-rw-r--r--script/vm/field.lua10
-rw-r--r--script/vm/generic.lua5
-rw-r--r--script/vm/global-manager.lua364
-rw-r--r--script/vm/global.lua431
-rw-r--r--script/vm/infer.lua115
-rw-r--r--script/vm/init.lua10
-rw-r--r--script/vm/library.lua21
-rw-r--r--script/vm/local-id.lua62
-rw-r--r--script/vm/local-manager.lua40
-rw-r--r--script/vm/manager.lua26
-rw-r--r--script/vm/node.lua260
-rw-r--r--script/vm/ref.lua6
-rw-r--r--script/vm/runner.lua444
-rw-r--r--script/vm/sign.lua29
-rw-r--r--script/vm/type.lua11
-rw-r--r--script/vm/value.lua30
-rw-r--r--script/vm/vm.lua1
-rw-r--r--script/workspace/loading.lua2
-rw-r--r--script/workspace/workspace.lua9
73 files changed, 3315 insertions, 1701 deletions
diff --git a/script/client.lua b/script/client.lua
index daa9bc52..d86fb4f2 100644
--- a/script/client.lua
+++ b/script/client.lua
@@ -248,6 +248,7 @@ local function tryModifyRC(uri, finalChanges, create)
end
local workspace = require 'workspace'
local path = workspace.getAbsolutePath(uri, '.luarc.json')
+ or workspace.getAbsolutePath(uri, '.luarc.jsonc')
if not path then
return false
end
@@ -318,7 +319,7 @@ local function tryModifyClientGlobal(finalChanges)
end
---@param changes config.change[]
----@param onlyMemory boolean
+---@param onlyMemory? boolean
function m.setConfig(changes, onlyMemory)
local finalChanges = {}
for _, change in ipairs(changes) do
diff --git a/script/config/loader.lua b/script/config/loader.lua
index c53f9399..30711dde 100644
--- a/script/config/loader.lua
+++ b/script/config/loader.lua
@@ -1,10 +1,10 @@
-local json = require 'json'
local proto = require 'proto'
local lang = require 'language'
local util = require 'utility'
local workspace = require 'workspace'
local scope = require 'workspace.scope'
local inspect = require 'inspect'
+local jsonc = require 'jsonc'
local function errorMessage(msg)
proto.notify('window/showMessage', {
@@ -29,7 +29,7 @@ function m.loadRCConfig(uri, filename)
scp:set('lastRCConfig', nil)
return nil
end
- local suc, res = pcall(json.decode, buf)
+ local suc, res = pcall(jsonc.decode, buf)
if not suc then
errorMessage(lang.script('CONFIG_LOAD_ERROR', res))
return scp:get('lastRCConfig')
@@ -55,7 +55,7 @@ function m.loadLocalConfig(uri, filename)
end
local firstChar = buf:match '%S'
if firstChar == '{' then
- local suc, res = pcall(json.decode, buf)
+ local suc, res = pcall(jsonc.decode, buf)
if not suc then
errorMessage(lang.script('CONFIG_LOAD_ERROR', res))
return scp:get('lastLocalConfig')
diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua
index beff594c..d4c20c60 100644
--- a/script/core/completion/completion.lua
+++ b/script/core/completion/completion.lua
@@ -16,10 +16,8 @@ local rpath = require 'workspace.require-path'
local lang = require 'language'
local lookBackward = require 'core.look-backward'
local guide = require 'parser.guide'
-local infer = require 'vm.infer'
local await = require 'await'
local postfix = require 'core.completion.postfix'
-local globalMgr = require 'vm.global-manager'
local diagnosticModes = {
'disable-next-line',
@@ -186,11 +184,8 @@ local function buildFunctionSnip(source, value, oop)
end
local function buildDetail(source)
- if source.type == 'dummy' then
- return
- end
- local types = infer.getInfer(source):view()
- local literals = infer.getInfer(source):viewLiterals()
+ local types = vm.getInfer(source):view()
+ local literals = vm.getInfer(source):viewLiterals()
if literals then
return types .. ' = ' .. literals
else
@@ -228,9 +223,6 @@ end
---@async
local function buildDesc(source)
- if source.type == 'dummy' then
- return
- end
local desc = markdown()
local hover = getHover.get(source)
desc:add('md', hover)
@@ -310,8 +302,23 @@ local function checkLocal(state, word, position, results)
if name:sub(1, 1) == '@' then
goto CONTINUE
end
- if infer.getInfer(source):hasFunction() then
- for _, def in ipairs(vm.getDefs(source)) do
+ if vm.getInfer(source):hasFunction() then
+ local defs = vm.getDefs(source)
+ -- make sure `function` is before `doc.type.function`
+ local orders = {}
+ for i, def in ipairs(defs) do
+ if def.type == 'function' then
+ orders[def] = i - 20000
+ elseif def.type == 'doc.type.function' then
+ orders[def] = i - 10000
+ else
+ orders[def] = i
+ end
+ end
+ table.sort(defs, function (a, b)
+ return orders[a] < orders[b]
+ end)
+ for _, def in ipairs(defs) do
if def.type == 'function'
or def.type == 'doc.type.function' then
local funcLabel = name .. getParams(def, false)
@@ -358,7 +365,7 @@ local function checkModule(state, word, position, results)
local fileName = path:match '[^/\\]*$'
local stemName = fileName:gsub('%..+', '')
if not locals[stemName]
- and not globalMgr.hasGlobalSets(state.uri, 'variable', stemName)
+ and not vm.hasGlobalSets(state.uri, 'variable', stemName)
and not config.get(state.uri, 'Lua.diagnostics.globals')[stemName]
and stemName:match '^[%a_][%w_]*$'
and matchKey(word, stemName) then
@@ -505,7 +512,7 @@ local function checkFieldThen(state, name, src, word, startPos, position, parent
})
return
end
- if oop and not infer.getInfer(src):hasFunction() then
+ if oop and not vm.getInfer(src):hasFunction() then
return
end
local literal = guide.getLiteral(value)
@@ -608,14 +615,14 @@ end
---@async
local function checkGlobal(state, word, startPos, position, parent, oop, results)
local locals = guide.getVisibleLocals(state.ast, position)
- local globals = globalMgr.getGlobalSets(state.uri, 'variable')
+ local globals = vm.getGlobalSets(state.uri, 'variable')
checkFieldOfRefs(globals, state, word, startPos, position, parent, oop, results, locals, 'global')
end
---@async
local function checkField(state, word, start, position, parent, oop, results)
if parent.tag == '_ENV' or parent.special == '_G' then
- local globals = globalMgr.getGlobalSets(state.uri, 'variable')
+ local globals = vm.getGlobalSets(state.uri, 'variable')
checkFieldOfRefs(globals, state, word, start, position, parent, oop, results)
else
local refs = vm.getFields(parent)
@@ -1124,7 +1131,7 @@ local function checkTypingEnum(state, position, defs, str, results)
or def.type == 'doc.type.integer'
or def.type == 'doc.type.boolean' then
enums[#enums+1] = {
- label = infer.viewObject(def),
+ label = vm.viewObject(def),
description = def.comment and def.comment.text,
kind = define.CompletionItemKind.EnumMember,
}
@@ -1413,7 +1420,7 @@ local function tryCallArg(state, position, results)
or src.type == 'doc.type.integer'
or src.type == 'doc.type.boolean' then
enums[#enums+1] = {
- label = infer.viewObject(src),
+ label = vm.viewObject(src),
description = src.comment,
kind = define.CompletionItemKind.EnumMember,
}
@@ -1432,7 +1439,7 @@ local function tryCallArg(state, position, results)
: string()
end
enums[#enums+1] = {
- label = infer.getInfer(src):view(),
+ label = vm.getInfer(src):view(),
description = description,
kind = define.CompletionItemKind.Function,
insertText = insertText,
@@ -1520,6 +1527,7 @@ local function tryluaDocCate(word, results)
'module',
'async',
'nodiscard',
+ 'cast',
} do
if matchKey(word, docType) then
results[#results+1] = {
@@ -1668,8 +1676,27 @@ local function tryluaDocBySource(state, position, source, results)
}
end
end
+ return true
elseif source.type == 'doc.module' then
collectRequireNames('require', state.uri, source.module or '', source, source.smark, position, results)
+ return true
+ elseif source.type == 'doc.cast.name' then
+ local locals = guide.getVisibleLocals(state.ast, position)
+ for name, loc in util.sortPairs(locals) do
+ if matchKey(source[1], name) then
+ results[#results+1] = {
+ label = name,
+ kind = define.CompletionItemKind.Variable,
+ id = stack(function () ---@async
+ return {
+ detail = buildDetail(loc),
+ description = buildDesc(loc),
+ }
+ end),
+ }
+ end
+ end
+ return true
end
return false
end
@@ -1764,6 +1791,22 @@ local function tryluaDocByErr(state, position, err, docState, results)
end
elseif err.type == 'LUADOC_MISS_MODULE_NAME' then
collectRequireNames('require', state.uri, '', docState, nil, position, results)
+ elseif err.type == 'LUADOC_MISS_LOCAL_NAME' then
+ local locals = guide.getVisibleLocals(state.ast, position)
+ for name, loc in util.sortPairs(locals) do
+ if name ~= '_ENV' then
+ results[#results+1] = {
+ label = name,
+ kind = define.CompletionItemKind.Variable,
+ id = stack(function () ---@async
+ return {
+ detail = buildDetail(loc),
+ description = buildDesc(loc),
+ }
+ end),
+ }
+ end
+ end
end
end
@@ -1775,14 +1818,14 @@ local function buildluaDocOfFunction(func)
local returns = {}
if func.args then
for _, arg in ipairs(func.args) do
- args[#args+1] = infer.getInfer(arg):view()
+ args[#args+1] = vm.getInfer(arg):view()
end
end
if func.returns then
for _, rtns in ipairs(func.returns) do
for n = 1, #rtns do
if not returns[n] then
- returns[n] = infer.getInfer(rtns[n]):view()
+ returns[n] = vm.getInfer(rtns[n]):view()
end
end
end
@@ -1882,6 +1925,9 @@ local function tryComment(state, position, results)
local doc = getluaDoc(state, position)
if not word then
local comment = getComment(state, position)
+ if not comment then
+ return
+ end
if comment.type == 'comment.short'
or comment.type == 'comment.cshort' then
if comment.text == '' then
diff --git a/script/core/definition.lua b/script/core/definition.lua
index b89aa751..e4868532 100644
--- a/script/core/definition.lua
+++ b/script/core/definition.lua
@@ -53,6 +53,7 @@ local accept = {
['doc.alias.name'] = true,
['doc.see.name'] = true,
['doc.see.field'] = true,
+ ['doc.cast.name'] = true,
}
local function checkRequire(source, offset)
@@ -133,6 +134,9 @@ return function (uri, offset)
local defs = vm.getDefs(source)
for _, src in ipairs(defs) do
+ if src.type == 'global' then
+ goto CONTINUE
+ end
local root = guide.getRoot(src)
if not root then
goto CONTINUE
diff --git a/script/core/diagnostics/close-non-object.lua b/script/core/diagnostics/close-non-object.lua
index b9d3c485..c97014fa 100644
--- a/script/core/diagnostics/close-non-object.lua
+++ b/script/core/diagnostics/close-non-object.lua
@@ -1,6 +1,7 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
+local vm = require 'vm'
return function (uri, callback)
local state = files.getState(uri)
@@ -23,18 +24,16 @@ return function (uri, callback)
}
return
end
- if source.value.type == 'nil'
- or source.value.type == 'number'
- or source.value.type == 'integer'
- or source.value.type == 'boolean'
- or source.value.type == 'table'
- or source.value.type == 'function' then
+ local infer = vm.getInfer(source.value)
+ if not infer:hasClass()
+ and not infer:hasType 'nil'
+ and not infer:hasType 'table'
+ and infer:view('any', uri) ~= 'any' then
callback {
start = source.value.start,
finish = source.value.finish,
message = lang.script.DIAG_COSE_NON_OBJECT,
}
- return
end
end)
end
diff --git a/script/core/diagnostics/duplicate-doc-field.lua b/script/core/diagnostics/duplicate-doc-field.lua
index 8d355aac..d4116b9b 100644
--- a/script/core/diagnostics/duplicate-doc-field.lua
+++ b/script/core/diagnostics/duplicate-doc-field.lua
@@ -1,6 +1,5 @@
local files = require 'files'
local lang = require 'language'
-local infer = require 'vm.infer'
local function getFieldEventName(doc)
if not doc.extends then
diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua
index d95963e4..334fd81a 100644
--- a/script/core/diagnostics/global-in-nil-env.lua
+++ b/script/core/diagnostics/global-in-nil-env.lua
@@ -16,7 +16,7 @@ return function (uri, callback)
local env = guide.getENV(root)
local nilDefs = {}
- if not env.ref then
+ if not env or not env.ref then
return
end
for _, ref in ipairs(env.ref) do
diff --git a/script/core/diagnostics/init.lua b/script/core/diagnostics/init.lua
index 369a6ba2..b4ae3715 100644
--- a/script/core/diagnostics/init.lua
+++ b/script/core/diagnostics/init.lua
@@ -105,7 +105,7 @@ end
---@param uri uri
---@param isScopeDiag boolean
---@param response async fun(result: any)
----@param checked async fun(name: string)
+---@param checked? async fun(name: string)
return function (uri, isScopeDiag, response, checked)
local ast = files.getState(uri)
if not ast then
diff --git a/script/core/diagnostics/lowercase-global.lua b/script/core/diagnostics/lowercase-global.lua
index d7032c13..d03e8c70 100644
--- a/script/core/diagnostics/lowercase-global.lua
+++ b/script/core/diagnostics/lowercase-global.lua
@@ -1,8 +1,8 @@
-local files = require 'files'
-local guide = require 'parser.guide'
-local lang = require 'language'
-local config = require 'config'
-local vm = require 'vm'
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local config = require 'config'
+local vm = require 'vm'
local function isDocClass(source)
if not source.bindDocs then
@@ -30,7 +30,7 @@ return function (uri, callback)
guide.eachSourceType(ast.ast, 'setglobal', function (source)
local name = guide.getKeyName(source)
- if definedGlobal[name] then
+ if not name or definedGlobal[name] then
return
end
local first = name:match '%w'
@@ -44,8 +44,17 @@ return function (uri, callback)
if isDocClass(source) then
return
end
- if vm.isGlobalLibraryName(name) then
- return
+ if definedGlobal[name] == nil then
+ definedGlobal[name] = false
+ local global = vm.getGlobal('variable', name)
+ if global then
+ for _, set in ipairs(global:getSets(uri)) do
+ if vm.isMetaFile(guide.getUri(set)) then
+ definedGlobal[name] = true
+ return
+ end
+ end
+ end
end
callback {
start = source.start,
diff --git a/script/core/diagnostics/missing-parameter.lua b/script/core/diagnostics/missing-parameter.lua
new file mode 100644
index 00000000..698680ca
--- /dev/null
+++ b/script/core/diagnostics/missing-parameter.lua
@@ -0,0 +1,73 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local lang = require 'language'
+
+local function countCallArgs(source)
+ local result = 0
+ if not source.args then
+ return 0
+ end
+ result = result + #source.args
+ return result
+end
+
+---@return integer
+local function countFuncArgs(source)
+ if not source.args or #source.args == 0 then
+ return 0
+ end
+ local count = 0
+ for i = #source.args, 1, -1 do
+ local arg = source.args[i]
+ if arg.type ~= '...'
+ and not (arg.name and arg.name[1] =='...')
+ and not vm.compileNode(arg):isNullable() then
+ return i
+ end
+ end
+ return count
+end
+
+local function getFuncArgs(func)
+ local funcArgs
+ local defs = vm.getDefs(func)
+ for _, def in ipairs(defs) do
+ if def.type == 'function'
+ or def.type == 'doc.type.function' then
+ local args = countFuncArgs(def)
+ if not funcArgs or args < funcArgs then
+ funcArgs = args
+ end
+ end
+ end
+ return funcArgs
+end
+
+return function (uri, callback)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+
+ guide.eachSourceType(state.ast, 'call', function (source)
+ local callArgs = countCallArgs(source)
+
+ local func = source.node
+ local funcArgs = getFuncArgs(func)
+
+ if not funcArgs then
+ return
+ end
+
+ local delta = callArgs - funcArgs
+ if delta >= 0 then
+ return
+ end
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script('DIAG_MISS_ARGS', funcArgs, callArgs),
+ }
+ end)
+end
diff --git a/script/core/diagnostics/need-check-nil.lua b/script/core/diagnostics/need-check-nil.lua
new file mode 100644
index 00000000..98fdfd08
--- /dev/null
+++ b/script/core/diagnostics/need-check-nil.lua
@@ -0,0 +1,39 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local lang = require 'language'
+
+return function (uri, callback)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+
+ guide.eachSourceType(state.ast, 'getlocal', function (src)
+ local checkNil
+ local nxt = src.next
+ if nxt then
+ if nxt.type == 'getfield'
+ or nxt.type == 'getmethod'
+ or nxt.type == 'getindex'
+ or nxt.type == 'call' then
+ checkNil = true
+ end
+ end
+ local call = src.parent
+ if call and call.type == 'call' and call.node == src then
+ checkNil = true
+ end
+ if not checkNil then
+ return
+ end
+ local node = vm.compileNode(src)
+ if node:hasFalsy() then
+ callback {
+ start = src.start,
+ finish = src.finish,
+ message = lang.script('DIAG_NEED_CHECK_NIL'),
+ }
+ end
+ end)
+end
diff --git a/script/core/diagnostics/no-unknown.lua b/script/core/diagnostics/no-unknown.lua
index 2199b6a8..48aab5da 100644
--- a/script/core/diagnostics/no-unknown.lua
+++ b/script/core/diagnostics/no-unknown.lua
@@ -1,7 +1,7 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
-local infer = require 'vm.infer'
+local vm = require 'vm'
return function (uri, callback)
local ast = files.getState(uri)
@@ -20,7 +20,7 @@ return function (uri, callback)
and source.type ~= 'tableindex' then
return
end
- if infer.getInfer(source):view() == 'unknown' then
+ if vm.getInfer(source):view() == 'unknown' then
callback {
start = source.start,
finish = source.finish,
diff --git a/script/core/diagnostics/not-yieldable.lua b/script/core/diagnostics/not-yieldable.lua
index 0588bbde..a1c84276 100644
--- a/script/core/diagnostics/not-yieldable.lua
+++ b/script/core/diagnostics/not-yieldable.lua
@@ -3,7 +3,6 @@ local await = require 'await'
local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
-local infer = require 'vm.infer'
local function isYieldAble(defs, i)
local hasFuncDef
@@ -12,7 +11,7 @@ local function isYieldAble(defs, i)
local arg = def.args and def.args[i]
if arg then
hasFuncDef = true
- if infer.getInfer(arg):hasType 'any'
+ if vm.getInfer(arg):hasType 'any'
or vm.isAsync(arg, true)
or arg.type == '...' then
return true
@@ -23,7 +22,7 @@ local function isYieldAble(defs, i)
local arg = def.args and def.args[i]
if arg then
hasFuncDef = true
- if infer.getInfer(arg.extends):hasType 'any'
+ if vm.getInfer(arg.extends):hasType 'any'
or vm.isAsync(arg.extends, true) then
return true
end
diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua
index 4adf169e..41781df8 100644
--- a/script/core/diagnostics/redundant-parameter.lua
+++ b/script/core/diagnostics/redundant-parameter.lua
@@ -2,7 +2,6 @@ local files = require 'files'
local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
-local define = require 'proto.define'
local function countCallArgs(source)
local result = 0
@@ -14,64 +13,40 @@ local function countCallArgs(source)
end
local function countFuncArgs(source)
- local result = 0
if not source.args or #source.args == 0 then
- return result
- end
- if source.args[#source.args].type == '...' then
- return math.maxinteger
- end
- result = result + #source.args
- return result
-end
-
-local function countOverLoadArgs(source, doc)
- local result = 0
- local func = doc.overload
- if not func.args or #func.args == 0 then
- return result
+ return 0
end
- if func.args[#func.args].type == '...' then
+ local lastArg = source.args[#source.args]
+ if lastArg.type == '...'
+ or (lastArg.name and lastArg.name[1] == '...') then
return math.maxinteger
+ else
+ return #source.args
end
- result = result + #func.args
- return result
end
local function getFuncArgs(func)
local funcArgs
local defs = vm.getDefs(func)
for _, def in ipairs(defs) do
- if def.value then
- def = def.value
- end
- if def.type == 'function' then
+ if def.type == 'function'
+ or def.type == 'doc.type.function' then
local args = countFuncArgs(def)
if not funcArgs or args > funcArgs then
funcArgs = args
end
- if def.bindDocs then
- for _, doc in ipairs(def.bindDocs) do
- if doc.type == 'doc.overload' then
- args = countOverLoadArgs(def, doc)
- if not funcArgs or args > funcArgs then
- funcArgs = args
- end
- end
- end
- end
end
end
return funcArgs
end
return function (uri, callback)
- local ast = files.getState(uri)
- if not ast then
+ local state = files.getState(uri)
+ if not state then
return
end
- guide.eachSourceType(ast.ast, 'call', function (source)
+ guide.eachSourceType(state.ast, 'call', function (source)
local callArgs = countCallArgs(source)
if callArgs == 0 then
return
@@ -97,7 +72,6 @@ return function (uri, callback)
callback {
start = arg.start,
finish = arg.finish,
- tags = { define.DiagnosticTag.Unnecessary },
message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, callArgs)
}
end
diff --git a/script/core/diagnostics/undefined-field.lua b/script/core/diagnostics/undefined-field.lua
index 025c217a..41fcda48 100644
--- a/script/core/diagnostics/undefined-field.lua
+++ b/script/core/diagnostics/undefined-field.lua
@@ -3,7 +3,6 @@ local vm = require 'vm'
local lang = require 'language'
local guide = require 'parser.guide'
local await = require 'await'
-local infer = require 'vm.infer'
local skipCheckClass = {
['unknown'] = true,
@@ -35,7 +34,7 @@ return function (uri, callback)
local node = src.node
if node then
local ok
- for view in infer.getInfer(node):eachView() do
+ for view in vm.getInfer(node):eachView() do
if not skipCheckClass[view] then
ok = true
break
diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua
index 139fa74f..bd0aae69 100644
--- a/script/core/diagnostics/undefined-global.lua
+++ b/script/core/diagnostics/undefined-global.lua
@@ -4,7 +4,6 @@ local lang = require 'language'
local config = require 'config'
local guide = require 'parser.guide'
local await = require 'await'
-local globalMgr = require 'vm.global-manager'
local requireLike = {
['include'] = true,
@@ -41,7 +40,7 @@ return function (uri, callback)
return
end
if cache[key] == nil then
- cache[key] = globalMgr.hasGlobalSets(uri, 'variable', key)
+ cache[key] = vm.hasGlobalSets(uri, 'variable', key)
end
if cache[key] then
return
diff --git a/script/core/diagnostics/unused-function.lua b/script/core/diagnostics/unused-function.lua
index 79cb16e2..813ac804 100644
--- a/script/core/diagnostics/unused-function.lua
+++ b/script/core/diagnostics/unused-function.lua
@@ -18,75 +18,107 @@ local function isToBeClosed(source)
return false
end
----@async
-return function (uri, callback)
- local ast = files.getState(uri)
- if not ast then
- return
+---@param source parser.object
+local function isValidFunction(source)
+ if not source then
+ return false
+ end
+ if source.type == 'main' then
+ return false
+ end
+ local parent = source.parent
+ if not parent then
+ return false
+ end
+ if parent.type ~= 'local'
+ and parent.type ~= 'setlocal' then
+ return false
+ end
+ if isToBeClosed(parent) then
+ return false
end
+ return true
+end
- local cache = {}
+---@async
+local function collect(ast, white, roots, links)
---@async
- local function checkFunction(source)
- if not source then
+ guide.eachSourceType(ast, 'function', function (src)
+ await.delay()
+ if not isValidFunction(src) then
return
end
- if cache[source] ~= nil then
- return cache[source]
- end
- cache[source] = false
- local parent = source.parent
- if not parent then
- return false
- end
- if parent.type ~= 'local'
- and parent.type ~= 'setlocal' then
- return false
- end
- if isToBeClosed(parent) then
- return false
+ local loc = src.parent
+ if loc.type == 'setlocal' then
+ loc = loc.node
end
- await.delay()
- if parent.type == 'setlocal' then
- parent = parent.node
- end
- local refs = parent.ref
- local hasGet
- if refs then
- for _, src in ipairs(refs) do
- if guide.isGet(src) then
- local func = guide.getParentFunction(src)
- if not checkFunction(func) then
- hasGet = true
- break
- end
+ for _, ref in ipairs(loc.ref or {}) do
+ if ref.type == 'getlocal' then
+ local func = guide.getParentFunction(ref)
+ if not isValidFunction(func) or roots[func] then
+ roots[src] = true
+ return
end
+ if not links[func] then
+ links[func] = {}
+ end
+ links[func][#links[func]+1] = src
end
end
- if not hasGet then
- if client.isVSCode() then
- callback {
- start = source.start,
- finish = source.finish,
- tags = { define.DiagnosticTag.Unnecessary },
- message = lang.script.DIAG_UNUSED_FUNCTION,
- }
- else
- callback {
- start = source.keyword[1],
- finish = source.keyword[2],
- tags = { define.DiagnosticTag.Unnecessary },
- message = lang.script.DIAG_UNUSED_FUNCTION,
- }
- end
- cache[source] = true
- return true
- end
- return false
+ white[src] = true
+ end)
+
+ return white, roots, links
+end
+
+local function turnBlack(source, black, white, links)
+ if black[source] then
+ return
end
+ black[source] = true
+ white[source] = nil
+ for _, link in ipairs(links[source] or {}) do
+ turnBlack(link, black, white, links)
+ end
+end
- -- 只检查局部函数
- guide.eachSourceType(ast.ast, 'function', function (source) ---@async
- checkFunction(source)
- end)
+---@async
+return function (uri, callback)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+
+ if vm.isMetaFile(uri) then
+ return
+ end
+
+ local black = {}
+ local white = {}
+ local roots = {}
+ local links = {}
+
+ collect(state.ast, white, roots, links)
+
+ for source in pairs(roots) do
+ turnBlack(source, black, white, links)
+ end
+
+ for source in pairs(white) do
+ if client.isVSCode() then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_UNUSED_FUNCTION,
+ }
+ else
+ callback {
+ start = source.keyword[1],
+ finish = source.keyword[2],
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_UNUSED_FUNCTION,
+ }
+ end
+ end
end
diff --git a/script/core/diagnostics/unused-vararg.lua b/script/core/diagnostics/unused-vararg.lua
index 2e07e1ee..ce033cf3 100644
--- a/script/core/diagnostics/unused-vararg.lua
+++ b/script/core/diagnostics/unused-vararg.lua
@@ -2,6 +2,7 @@ local files = require 'files'
local guide = require 'parser.guide'
local define = require 'proto.define'
local lang = require 'language'
+local vm = require 'vm'
return function (uri, callback)
local ast = files.getState(uri)
@@ -9,6 +10,10 @@ return function (uri, callback)
return
end
+ if vm.isMetaFile(uri) then
+ return
+ end
+
guide.eachSourceType(ast.ast, 'function', function (source)
local args = source.args
if not args then
diff --git a/script/core/formatting.lua b/script/core/formatting.lua
index 49da6861..b52854a4 100644
--- a/script/core/formatting.lua
+++ b/script/core/formatting.lua
@@ -3,7 +3,7 @@ local files = require("files")
local log = require("log")
return function(uri, options)
- local text = files.getText(uri)
+ local text = files.getOriginText(uri)
local ast = files.getState(uri)
local status, formattedText = codeFormat.format(uri, text, options)
diff --git a/script/core/hint.lua b/script/core/hint.lua
index f6774d2a..f97cdcec 100644
--- a/script/core/hint.lua
+++ b/script/core/hint.lua
@@ -1,5 +1,4 @@
local files = require 'files'
-local infer = require 'vm.infer'
local vm = require 'vm'
local config = require 'config'
local guide = require 'parser.guide'
@@ -39,7 +38,7 @@ local function typeHint(uri, results, start, finish)
end
end
await.delay()
- local view = infer.getInfer(source):view()
+ local view = vm.getInfer(source):view()
if view == 'any'
or view == 'unknown'
or view == 'nil' then
diff --git a/script/core/hover/args.lua b/script/core/hover/args.lua
index a53136b0..c485d9b9 100644
--- a/script/core/hover/args.lua
+++ b/script/core/hover/args.lua
@@ -1,17 +1,5 @@
local guide = require 'parser.guide'
-local infer = require 'vm.infer'
-
-local function optionalArg(arg)
- if not arg.bindDocs then
- return false
- end
- local name = arg[1]
- for _, doc in ipairs(arg.bindDocs) do
- if doc.type == 'doc.param' and doc.param[1] == name then
- return doc.optional
- end
- end
-end
+local vm = require 'vm'
local function asFunction(source)
local args = {}
@@ -21,7 +9,7 @@ local function asFunction(source)
methodDef = true
end
if methodDef then
- args[#args+1] = ('self: %s'):format(infer.getInfer(parent.node):view 'any')
+ args[#args+1] = ('self: %s'):format(vm.getInfer(parent.node):view 'any')
end
if source.args then
for i = 1, #source.args do
@@ -31,18 +19,25 @@ local function asFunction(source)
end
local name = arg.name or guide.getKeyName(arg)
if name then
+ local argNode = vm.compileNode(arg)
+ local optional
+ if argNode:isOptional() then
+ optional = true
+ argNode = argNode:copy()
+ argNode:removeOptional()
+ end
args[#args+1] = ('%s%s: %s'):format(
name,
- optionalArg(arg) and '?' or '',
- infer.getInfer(arg):view 'any'
+ optional and '?' or '',
+ vm.getInfer(argNode):view('any', guide.getUri(source))
)
elseif arg.type == '...' then
args[#args+1] = ('%s: %s'):format(
'...',
- infer.getInfer(arg):view 'any'
+ vm.getInfer(arg):view 'any'
)
else
- args[#args+1] = ('%s'):format(infer.getInfer(arg):view 'any')
+ args[#args+1] = ('%s'):format(vm.getInfer(arg):view 'any')
end
::CONTINUE::
end
@@ -61,7 +56,7 @@ local function asDocFunction(source)
args[i] = ('%s%s: %s'):format(
name,
arg.optional and '?' or '',
- arg.extends and infer.getInfer(arg.extends):view 'any' or 'any'
+ arg.extends and vm.getInfer(arg.extends):view 'any' or 'any'
)
end
return args
diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua
index 03f6128a..e9267c0f 100644
--- a/script/core/hover/description.lua
+++ b/script/core/hover/description.lua
@@ -6,7 +6,6 @@ local lang = require 'language'
local util = require 'utility'
local guide = require 'parser.guide'
local rpath = require 'workspace.require-path'
-local infer = require 'vm.infer'
local function collectRequire(mode, literal, uri)
local result, searchers
@@ -153,7 +152,7 @@ local function buildEnumChunk(docType, name)
local types = {}
local lines = {}
for _, tp in ipairs(vm.getDefs(docType)) do
- types[#types+1] = infer.getInfer(tp):view()
+ types[#types+1] = vm.getInfer(tp):view()
if tp.type == 'doc.type.string'
or tp.type == 'doc.type.integer'
or tp.type == 'doc.type.boolean' then
@@ -175,7 +174,7 @@ local function buildEnumChunk(docType, name)
(enum.default and '->')
or (enum.additional and '+>')
or ' |',
- infer.viewObject(enum)
+ vm.viewObject(enum)
)
if enum.comment then
local first = true
diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua
index bc2f40eb..7231944a 100644
--- a/script/core/hover/init.lua
+++ b/script/core/hover/init.lua
@@ -5,7 +5,6 @@ local getDesc = require 'core.hover.description'
local util = require 'utility'
local findSource = require 'core.find-source'
local markdown = require 'provider.markdown'
-local infer = require 'vm.infer'
local guide = require 'parser.guide'
---@async
@@ -40,9 +39,24 @@ local function getHover(source)
end
local oop
- if infer.getInfer(source):view() == 'function' then
+ if vm.getInfer(source):view() == 'function' then
+ local defs = vm.getDefs(source)
+ -- make sure `function` is before `doc.type.function`
+ local orders = {}
+ for i, def in ipairs(defs) do
+ if def.type == 'function' then
+ orders[def] = i - 20000
+ elseif def.type == 'doc.type.function' then
+ orders[def] = i - 10000
+ else
+ orders[def] = i
+ end
+ end
+ table.sort(defs, function (a, b)
+ return orders[a] < orders[b]
+ end)
local hasFunc
- for _, def in ipairs(vm.getDefs(source)) do
+ for _, def in ipairs(defs) do
if guide.isOOP(def) then
oop = true
end
@@ -58,6 +72,9 @@ local function getHover(source)
else
addHover(source, true, oop)
for _, def in ipairs(vm.getDefs(source)) do
+ if def.type == 'global' then
+ goto CONTINUE
+ end
if guide.isOOP(def) then
oop = true
end
@@ -67,6 +84,7 @@ local function getHover(source)
isFunction = true
end
addHover(def, isFunction, oop)
+ ::CONTINUE::
end
end
diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua
index 8224e9d3..2bbfe806 100644
--- a/script/core/hover/label.lua
+++ b/script/core/hover/label.lua
@@ -2,7 +2,6 @@ local buildName = require 'core.hover.name'
local buildArgs = require 'core.hover.args'
local buildReturn = require 'core.hover.return'
local buildTable = require 'core.hover.table'
-local infer = require 'vm.infer'
local vm = require 'vm'
local util = require 'utility'
local lang = require 'language'
@@ -34,7 +33,7 @@ local function asDocTypeName(source)
return '(class) ' .. doc.class[1]
end
if doc.type == 'doc.alias' then
- return '(alias) ' .. doc.alias[1] .. ' ' .. lang.script('HOVER_EXTENDS', infer.getInfer(doc.extends):view())
+ return '(alias) ' .. doc.alias[1] .. ' ' .. lang.script('HOVER_EXTENDS', vm.getInfer(doc.extends):view())
end
end
end
@@ -42,7 +41,7 @@ end
---@async
local function asValue(source, title)
local name = buildName(source, false) or ''
- local ifr = infer.getInfer(source)
+ local ifr = vm.getInfer(source)
local type = ifr:view()
local literal = ifr:viewLiterals()
local cont = buildTable(source)
@@ -140,7 +139,7 @@ local function asDocFieldName(source)
break
end
end
- local view = infer.getInfer(source.extends):view()
+ local view = vm.getInfer(source.extends):view()
if not class then
return ('(field) ?.%s: %s'):format(name, view)
end
@@ -180,7 +179,7 @@ local function asNumber(source)
if not text then
return nil
end
- local raw = text:sub(source.start, source.finish)
+ local raw = text:sub(source.start + 1, source.finish)
if not raw or not raw:find '[^%-%d%.]' then
return nil
end
diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua
index 905c5ec7..f8473638 100644
--- a/script/core/hover/name.lua
+++ b/script/core/hover/name.lua
@@ -1,5 +1,5 @@
-local infer = require 'vm.infer'
local guide = require 'parser.guide'
+local vm = require 'vm'
local buildName
@@ -19,7 +19,7 @@ end
local function asField(source, oop)
local class
if source.node.type ~= 'getglobal' then
- class = infer.getInfer(source.node):viewClass()
+ class = vm.getInfer(source.node):viewClass()
end
local node = class
or buildName(source.node, false)
diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua
index 77710148..3d8a94a5 100644
--- a/script/core/hover/return.lua
+++ b/script/core/hover/return.lua
@@ -1,5 +1,3 @@
-local infer = require 'vm.infer'
-local guide = require 'parser.guide'
local vm = require 'vm.vm'
---@param source parser.object
@@ -65,10 +63,9 @@ local function asFunction(source)
local rtn = vm.getReturnOfFunction(source, i)
local doc = docs[i]
local name = doc and doc.name and doc.name[1] and (doc.name[1] .. ': ')
- local text = ('%s%s%s'):format(
+ local text = ('%s%s'):format(
name or '',
- infer.getInfer(rtn):view(),
- doc and doc.optional and '?' or ''
+ vm.getInfer(rtn):view()
)
if i == 1 then
returns[i] = (' -> %s'):format(text)
@@ -86,10 +83,7 @@ local function asDocFunction(source)
end
local returns = {}
for i, rtn in ipairs(source.returns) do
- local rtnText = ('%s%s'):format(
- infer.getInfer(rtn):view(),
- rtn.optional and '?' or ''
- )
+ local rtnText = vm.getInfer(rtn):view()
if i == 1 then
returns[#returns+1] = (' -> %s'):format(rtnText)
else
diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua
index 31036edd..16874101 100644
--- a/script/core/hover/table.lua
+++ b/script/core/hover/table.lua
@@ -1,7 +1,6 @@
local vm = require 'vm'
local util = require 'utility'
local config = require 'config'
-local infer = require 'vm.infer'
local await = require 'await'
local guide = require 'parser.guide'
@@ -16,22 +15,34 @@ local function formatKey(key)
return ('[%s]'):format(key)
end
-local function buildAsHash(keys, typeMap, literalMap, optMap, reachMax)
+---@param uri uri
+---@param keys string[]
+---@param nodeMap table<string, vm.node>
+---@param reachMax integer
+local function buildAsHash(uri, keys, nodeMap, reachMax)
local lines = {}
lines[#lines+1] = '{'
for _, key in ipairs(keys) do
- local typeView = typeMap[key]
- local literalView = literalMap[key]
+ local node = nodeMap[key]
+ local isOptional = node:isOptional()
+ if isOptional then
+ node = node:copy()
+ node:removeOptional()
+ end
+ local ifr = vm.getInfer(node)
+ local typeView = ifr:view('unknown', uri)
+ local literalView = ifr:viewLiterals()
if literalView then
lines[#lines+1] = (' %s%s: %s = %s,'):format(
formatKey(key),
- optMap[key] and '?' or '',
+ isOptional and '?' or '',
typeView,
- literalView)
+ literalView
+ )
else
lines[#lines+1] = (' %s%s: %s,'):format(
formatKey(key),
- optMap[key] and '?' or '',
+ isOptional and '?' or '',
typeView
)
end
@@ -43,26 +54,40 @@ local function buildAsHash(keys, typeMap, literalMap, optMap, reachMax)
return table.concat(lines, '\n')
end
-local function buildAsConst(keys, typeMap, literalMap, optMap, reachMax)
+---@param uri uri
+---@param keys string[]
+---@param nodeMap table<string, vm.node>
+---@param reachMax integer
+local function buildAsConst(uri, keys, nodeMap, reachMax)
+ local literalMap = {}
+ for _, key in ipairs(keys) do
+ literalMap[key] = vm.getInfer(nodeMap[key]):viewLiterals()
+ end
table.sort(keys, function (a, b)
return tonumber(literalMap[a]) < tonumber(literalMap[b])
end)
local lines = {}
lines[#lines+1] = '{'
for _, key in ipairs(keys) do
- local typeView = typeMap[key]
+ local node = nodeMap[key]
+ local isOptional = node:isOptional()
+ if isOptional then
+ node = node:copy()
+ node:removeOptional()
+ end
+ local typeView = vm.getInfer(node):view('unknown', uri)
local literalView = literalMap[key]
if literalView then
lines[#lines+1] = (' %s%s: %s = %s,'):format(
formatKey(key),
- optMap[key] and '?' or '',
+ isOptional and '?' or '',
typeView,
literalView
)
else
lines[#lines+1] = (' %s%s: %s,'):format(
formatKey(key),
- optMap[key] and '?' or '',
+ isOptional and '?' or '',
typeView
)
end
@@ -102,6 +127,19 @@ local function getKeyMap(fields)
if ta == 'boolean' then
return a == true
end
+ if ta == 'string' then
+ if a:sub(1, 1) == '_' then
+ if b:sub(1, 1) == '_' then
+ return a < b
+ else
+ return false
+ end
+ elseif b:sub(1, 1) == '_' then
+ return true
+ else
+ return a < b
+ end
+ end
return a < b
else
return tsa < tsb
@@ -110,48 +148,25 @@ local function getKeyMap(fields)
return keys, map
end
-local function getOptMap(fields, keyMap)
- local optMap = {}
- for _, field in ipairs(fields) do
- if field.type == 'doc.field' then
- if field.optional then
- local key = vm.getKeyName(field)
- if keyMap[key] then
- optMap[key] = true
- end
- end
- end
- if field.type == 'doc.type.field' then
- if field.optional then
- local key = vm.getKeyName(field)
- if keyMap[key] then
- optMap[key] = true
- end
- end
- end
- end
- return optMap
-end
-
---@async
-local function getInferMap(fields, keyMap)
- ---@type table<string, vm.infer>
- local inferMap = {}
+local function getNodeMap(fields, keyMap)
+ ---@type table<string, vm.node>
+ local nodeMap = {}
for _, field in ipairs(fields) do
local key = vm.getKeyName(field)
if not keyMap[key] then
goto CONTINUE
end
await.delay()
- local ifr = infer.getInfer(field)
- if inferMap[key] then
- inferMap[key] = inferMap[key]:merge(ifr)
+ local node = vm.compileNode(field)
+ if nodeMap[key] then
+ nodeMap[key]:merge(node)
else
- inferMap[key] = ifr
+ nodeMap[key] = node:copy()
end
::CONTINUE::
end
- return inferMap
+ return nodeMap
end
---@async
@@ -163,7 +178,7 @@ return function (source)
return nil
end
- for view in infer.getInfer(source):eachView() do
+ for view in vm.getInfer(source):eachView() do
if view == 'string'
or vm.isSubType(uri, view, 'string') then
return nil
@@ -184,19 +199,14 @@ return function (source)
end
end
- local optMap = getOptMap(fields, map)
- local inferMap = getInferMap(fields, map)
+ local nodeMap = getNodeMap(fields, map)
- local typeMap = {}
- local literalMap = {}
local isConsts = true
for i = 1, #keys do
await.delay()
local key = keys[i]
-
- typeMap[key] = inferMap[key]:view('unknown', uri)
- literalMap[key] = inferMap[key]:viewLiterals()
- if not tonumber(literalMap[key]) then
+ local literal = vm.getInfer(nodeMap[key]):viewLiterals()
+ if not tonumber(literal) then
isConsts = false
end
end
@@ -204,9 +214,9 @@ return function (source)
local result
if isConsts then
- result = buildAsConst(keys, typeMap, literalMap, optMap, reachMax)
+ result = buildAsConst(uri, keys, nodeMap, reachMax)
else
- result = buildAsHash(keys, typeMap, literalMap, optMap, reachMax)
+ result = buildAsHash(uri, keys, nodeMap, reachMax)
end
--if timeUp then
diff --git a/script/core/look-backward.lua b/script/core/look-backward.lua
index eea089bc..eeee6017 100644
--- a/script/core/look-backward.lua
+++ b/script/core/look-backward.lua
@@ -2,7 +2,8 @@
local m = {}
--- 是否是空白符
----@param inline boolean # 必须在同一行中(排除换行符)
+---@param char string
+---@param inline? boolean # 必须在同一行中(排除换行符)
function m.isSpace(char, inline)
if inline then
if char == ' '
@@ -21,7 +22,9 @@ function m.isSpace(char, inline)
end
--- 跳过空白符
----@param inline boolean # 必须在同一行中(排除换行符)
+---@param text string
+---@param offset integer
+---@param inline? boolean # 必须在同一行中(排除换行符)
function m.skipSpace(text, offset, inline)
for i = offset, 1, -1 do
local char = text:sub(i, i)
diff --git a/script/core/matchkey.lua b/script/core/matchkey.lua
index 3c6a54a8..4db9d764 100644
--- a/script/core/matchkey.lua
+++ b/script/core/matchkey.lua
@@ -59,7 +59,7 @@ end
---@param input string
---@param other string
----@param fast boolean
+---@param fast? boolean
---@return boolean isMatch
---@return number deviation
return function (input, other, fast)
diff --git a/script/core/rangeformatting.lua b/script/core/rangeformatting.lua
index ccf2d21f..f64e9cda 100644
--- a/script/core/rangeformatting.lua
+++ b/script/core/rangeformatting.lua
@@ -4,7 +4,7 @@ local log = require("log")
local converter = require("proto.converter")
return function(uri, range, options)
- local text = files.getText(uri)
+ local text = files.getOriginText(uri)
local status, formattedText, startLine, endLine = codeFormat.range_format(
uri, text, range.start.line, range["end"].line, options)
diff --git a/script/core/rename.lua b/script/core/rename.lua
index ec21e87c..7599fad6 100644
--- a/script/core/rename.lua
+++ b/script/core/rename.lua
@@ -3,7 +3,6 @@ local vm = require 'vm'
local util = require 'utility'
local findSource = require 'core.find-source'
local guide = require 'parser.guide'
-local globalMgr = require 'vm.global-manager'
local Forcing
@@ -191,7 +190,7 @@ end
---@async
local function ofGlobal(source, newname, callback)
local key = guide.getKeyName(source)
- local global = globalMgr.getGlobal('variable', key)
+ local global = vm.getGlobal('variable', key)
if not global then
return
end
@@ -214,7 +213,7 @@ end
---@async
local function ofDocTypeName(source, newname, callback)
local oldname = source[1]
- local global = globalMgr.getGlobal('type', oldname)
+ local global = vm.getGlobal('type', oldname)
if not global then
return
end
diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua
index 568bb222..33449013 100644
--- a/script/core/semantic-tokens.lua
+++ b/script/core/semantic-tokens.lua
@@ -5,7 +5,6 @@ local vm = require 'vm'
local util = require 'utility'
local guide = require 'parser.guide'
local converter = require 'proto.converter'
-local infer = require 'vm.infer'
local config = require 'config'
local linkedTable = require 'linked-table'
@@ -16,8 +15,24 @@ local Care = util.switch()
if not options.variable then
return
end
- local isLib = vm.isGlobalLibraryName(source[1])
- local isFunc = infer.getInfer(source):hasFunction()
+
+ local name = source[1]
+ local isLib = options.libGlobals[name]
+ if isLib == nil then
+ isLib = false
+ local global = vm.getGlobal('variable', name)
+ if global then
+ local uri = guide.getUri(source)
+ for _, set in ipairs(global:getSets(uri)) do
+ if vm.isMetaFile(guide.getUri(set)) then
+ isLib = true
+ break
+ end
+ end
+ end
+ options.libGlobals[name] = isLib
+ end
+ local isFunc = vm.getInfer(source):hasFunction()
local type = isFunc and define.TokenTypes['function'] or define.TokenTypes.variable
local modifier = isLib and define.TokenModifiers.defaultLibrary or define.TokenModifiers.static
@@ -66,7 +81,7 @@ local Care = util.switch()
return
end
end
- if infer.getInfer(source):hasFunction() then
+ if vm.getInfer(source):hasFunction() then
results[#results+1] = {
start = source.start,
finish = source.finish,
@@ -165,27 +180,23 @@ local Care = util.switch()
-- 5. Class declaration
-- only search this local
if loc.bindDocs then
- for i = #loc.bindDocs, 1, -1 do
- local doc = loc.bindDocs[i]
- if doc.type == 'doc.type' then
- break
- end
- if doc.type == "doc.class" and doc.bindSources then
- for _, src in ipairs(doc.bindSources) do
- if src == loc then
- results[#results+1] = {
- start = source.start,
- finish = source.finish,
- type = define.TokenTypes.class,
- }
- return
- end
+ local isParam = source.parent.type == 'funcargs'
+ or source.parent.type == 'in'
+ if not isParam then
+ for _, doc in ipairs(loc.bindDocs) do
+ if doc.type == 'doc.class' then
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.class,
+ }
+ return
end
end
end
end
-- 6. References to other functions
- if infer.getInfer(loc):hasFunction() then
+ if vm.getInfer(loc):hasFunction() then
results[#results+1] = {
start = source.start,
finish = source.finish,
@@ -656,6 +667,14 @@ local Care = util.switch()
type = define.TokenTypes.keyword,
}
end)
+ : case 'doc.cast.name'
+ : call(function (source, options, results)
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.variable,
+ }
+ end)
local function buildTokens(uri, results)
local tokens = {}
@@ -773,24 +792,25 @@ end
---@async
return function (uri, start, finish)
+ local results = {}
if not config.get(uri, 'Lua.semantic.enable') then
- return nil
+ return results
end
local state = files.getState(uri)
if not state then
- return nil
+ return results
end
local options = {
uri = uri,
state = state,
text = files.getText(uri),
+ libGlobals = {},
variable = config.get(uri, 'Lua.semantic.variable'),
annotation = config.get(uri, 'Lua.semantic.annotation'),
keyword = config.get(uri, 'Lua.semantic.keyword'),
}
- local results = {}
guide.eachSourceBetween(state.ast, start, finish, function (source) ---@async
Care(source.type, source, options, results)
await.delay()
@@ -798,27 +818,26 @@ return function (uri, start, finish)
for _, comm in ipairs(state.comms) do
if start <= comm.start and comm.finish <= finish then
- if comm.type == 'comment.short' then
- local head = comm.text:match '^%-%s*[@|]'
- if head then
- results[#results+1] = {
- start = comm.start,
- finish = comm.start + #head + 1,
- type = define.TokenTypes.comment,
- }
- results[#results+1] = {
- start = comm.start + #head + 1,
- finish = comm.start + #head + 2 + #comm.text:match('%S*', #head + 1),
- type = define.TokenTypes.keyword,
- modifieres = define.TokenModifiers.documentation,
- }
+ local headPos = (comm.type == 'comment.short' and comm.text:match '^%-%s*[@|]()')
+ or (comm.type == 'comment.long' and comm.text:match '^@()')
+ if headPos then
+ local atPos
+ if comm.type == 'comment.short' then
+ atPos = headPos + 2
else
- results[#results+1] = {
- start = comm.start,
- finish = comm.finish,
- type = define.TokenTypes.comment,
- }
+ atPos = headPos + #comm.mark
end
+ results[#results+1] = {
+ start = comm.start,
+ finish = comm.start + atPos - 2,
+ type = define.TokenTypes.comment,
+ }
+ results[#results+1] = {
+ start = comm.start + atPos - 2,
+ finish = comm.start + atPos - 1 + #comm.text:match('%S*', headPos),
+ type = define.TokenTypes.keyword,
+ modifieres = define.TokenModifiers.documentation,
+ }
else
results[#results+1] = {
start = comm.start,
@@ -830,7 +849,7 @@ return function (uri, start, finish)
end
if #results == 0 then
- return {}
+ return results
end
results = solveMultilineAndOverlapping(state, results)
diff --git a/script/core/signature.lua b/script/core/signature.lua
index ab7268dd..025e70b7 100644
--- a/script/core/signature.lua
+++ b/script/core/signature.lua
@@ -41,6 +41,9 @@ end
---@async
local function makeOneSignature(source, oop, index)
local label = hoverLabel(source, oop)
+ if not label then
+ return nil
+ end
-- 去掉返回值
label = label:gsub('%s*->.+', '')
local params = {}
diff --git a/script/core/type-definition.lua b/script/core/type-definition.lua
index 92f81997..d8434c8c 100644
--- a/script/core/type-definition.lua
+++ b/script/core/type-definition.lua
@@ -3,7 +3,6 @@ local files = require 'files'
local vm = require 'vm'
local findSource = require 'core.find-source'
local guide = require 'parser.guide'
-local infer = require 'vm.infer'
local rpath = require 'workspace.require-path'
local function sortResults(results)
diff --git a/script/doctor.lua b/script/doctor.lua
index 91a7e4b8..87cdcfcb 100644
--- a/script/doctor.lua
+++ b/script/doctor.lua
@@ -175,6 +175,9 @@ m.snapshot = private(function ()
exclude[o] = true
end
end
+ ---@generic T
+ ---@param o T
+ ---@return T
local function private(o)
if not o then
return nil
diff --git a/script/encoder/init.lua b/script/encoder/init.lua
index 0011265a..3c8a58e0 100644
--- a/script/encoder/init.lua
+++ b/script/encoder/init.lua
@@ -10,9 +10,9 @@ local utf16be = utf16('be', utf8.codepoint '�')
local m = {}
---@param encoding encoder.encoding
----@param s string
----@param i integer
----@param j integer
+---@param s string
+---@param i? integer
+---@param j? integer
function m.len(encoding, s, i, j)
i = i or 1
j = j or #s
@@ -33,9 +33,9 @@ function m.len(encoding, s, i, j)
end
---@param encoding encoder.encoding
----@param s string
----@param n integer
----@param i integer
+---@param s string
+---@param n integer
+---@param i? integer
function m.offset(encoding, s, n, i)
i = i or 1
if encoding == 'utf16'
diff --git a/script/files.lua b/script/files.lua
index d16474fd..22c9ae31 100644
--- a/script/files.lua
+++ b/script/files.lua
@@ -165,8 +165,8 @@ end
--- 设置文件文本
---@param uri uri
---@param text string
----@param isTrust boolean
----@param callback function
+---@param isTrust? boolean
+---@param callback? function
function m.setText(uri, text, isTrust, callback)
if not text then
return
diff --git a/script/fs-utility.lua b/script/fs-utility.lua
index c845c769..08aae98a 100644
--- a/script/fs-utility.lua
+++ b/script/fs-utility.lua
@@ -281,12 +281,8 @@ local function fsIsDirectory(path, option)
if path.type == 'dummy' then
return path:isDirectory()
end
- local suc, res = pcall(fs.is_directory, path)
- if not suc then
- option.err[#option.err+1] = res
- return false
- end
- return res
+ local status = fs.symlink_status(path):type()
+ return status == 'directory'
end
local function fsPairs(path, option)
@@ -616,9 +612,10 @@ end
function m.scanDirectory(dir, callback)
for fullpath in fs.pairs(dir) do
- if fs.is_directory(fullpath) then
+ local status = fs.symlink_status(fullpath):type()
+ if status == 'directory' then
m.scanDirectory(fullpath, callback)
- else
+ elseif status == 'regular' then
callback(fullpath)
end
end
diff --git a/script/glob/gitignore.lua b/script/glob/gitignore.lua
index 09be1415..4dad2747 100644
--- a/script/glob/gitignore.lua
+++ b/script/glob/gitignore.lua
@@ -163,7 +163,7 @@ function mt:getRelativePath(path)
return path
end
----@param callback async fun()
+---@param callback async fun(path: string)
---@async
function mt:scan(path, callback)
local files = {}
diff --git a/script/jsonc.lua b/script/jsonc.lua
new file mode 100644
index 00000000..0361d99b
--- /dev/null
+++ b/script/jsonc.lua
@@ -0,0 +1,603 @@
+local type = type
+local next = next
+local error = error
+local tonumber = tonumber
+local tostring = tostring
+local table_concat = table.concat
+local table_sort = table.sort
+local string_char = string.char
+local string_byte = string.byte
+local string_find = string.find
+local string_match = string.match
+local string_gsub = string.gsub
+local string_sub = string.sub
+local string_rep = string.rep
+local string_format = string.format
+local setmetatable = setmetatable
+local getmetatable = getmetatable
+local huge = math.huge
+local tiny = -huge
+
+local utf8_char
+local math_type
+
+if _VERSION == "Lua 5.1" or _VERSION == "Lua 5.2" then
+ local math_floor = math.floor
+ function utf8_char(c)
+ if c <= 0x7f then
+ return string_char(c)
+ elseif c <= 0x7ff then
+ return string_char(math_floor(c / 64) + 192, c % 64 + 128)
+ elseif c <= 0xffff then
+ return string_char(
+ math_floor(c / 4096) + 224,
+ math_floor(c % 4096 / 64) + 128,
+ c % 64 + 128
+ )
+ elseif c <= 0x10ffff then
+ return string_char(
+ math_floor(c / 262144) + 240,
+ math_floor(c % 262144 / 4096) + 128,
+ math_floor(c % 4096 / 64) + 128,
+ c % 64 + 128
+ )
+ end
+ error(string.format("invalid UTF-8 code '%x'", c))
+ end
+ function math_type(v)
+ if v >= -2147483648 and v <= 2147483647 and math_floor(v) == v then
+ return "integer"
+ end
+ return "float"
+ end
+else
+ utf8_char = utf8.char
+ math_type = math.type
+end
+
+local json = {}
+
+json.supportSparseArray = true
+
+local objectMt = {}
+
+function json.createEmptyObject()
+ return setmetatable({}, objectMt)
+end
+
+function json.isObject(t)
+ if t[1] ~= nil then
+ return false
+ end
+ return next(t) ~= nil or getmetatable(t) == objectMt
+end
+
+if debug and debug.upvalueid then
+ -- Generate a lightuserdata
+ json.null = debug.upvalueid(json.createEmptyObject, 1)
+else
+ json.null = function() end
+end
+
+-- json.encode --
+
+local statusVisited
+local statusBuilder
+local statusDep
+local statusOpt
+
+local defaultOpt = {
+ newline = "",
+ indent = "",
+}
+defaultOpt.__index = defaultOpt
+
+local encode_map = {}
+
+local encode_escape_map = {
+ [ "\"" ] = "\\\"",
+ [ "\\" ] = "\\\\",
+ [ "/" ] = "\\/",
+ [ "\b" ] = "\\b",
+ [ "\f" ] = "\\f",
+ [ "\n" ] = "\\n",
+ [ "\r" ] = "\\r",
+ [ "\t" ] = "\\t",
+}
+
+local decode_escape_set = {}
+local decode_escape_map = {}
+for k, v in next, encode_escape_map do
+ decode_escape_map[v] = k
+ decode_escape_set[string_byte(v, 2)] = true
+end
+
+for i = 0, 31 do
+ local c = string_char(i)
+ if not encode_escape_map[c] then
+ encode_escape_map[c] = string_format("\\u%04x", i)
+ end
+end
+
+encode_map["nil"] = function ()
+ return "null"
+end
+
+local function encode_string(v)
+ return string_gsub(v, '[%z\1-\31\\"]', encode_escape_map)
+end
+
+local function convertreal(v)
+ local g = string_format('%.16g', v)
+ if tonumber(g) == v then
+ return g
+ end
+ return string_format('%.17g', v)
+end
+
+if string_match(tostring(1/2), "%p") == "," then
+ local _convertreal = convertreal
+ function convertreal(v)
+ return string_gsub(_convertreal(v), ',', '.')
+ end
+end
+
+function encode_map.number(v)
+ if v ~= v or v <= tiny or v >= huge then
+ error("unexpected number value '" .. tostring(v) .. "'")
+ end
+ if math_type(v) == "integer" then
+ return string_format('%d', v)
+ end
+ return convertreal(v)
+end
+
+function encode_map.boolean(v)
+ if v then
+ return "true"
+ else
+ return "false"
+ end
+end
+
+local function encode_unexpected(v)
+ if v == json.null then
+ return "null"
+ else
+ error("unexpected type '"..type(v).."'")
+ end
+end
+encode_map[ "function" ] = encode_unexpected
+encode_map[ "userdata" ] = encode_unexpected
+encode_map[ "thread" ] = encode_unexpected
+
+local function encode_newline()
+ statusBuilder[#statusBuilder+1] = statusOpt.newline..string_rep(statusOpt.indent, statusDep)
+end
+
+local function encode(v)
+ local res = encode_map[type(v)](v)
+ statusBuilder[#statusBuilder+1] = res
+end
+
+function encode_map.string(v)
+ statusBuilder[#statusBuilder+1] = '"'
+ statusBuilder[#statusBuilder+1] = encode_string(v)
+ return '"'
+end
+
+function encode_map.table(t)
+ local first_val = next(t)
+ if first_val == nil then
+ if getmetatable(t) == objectMt then
+ return "{}"
+ else
+ return "[]"
+ end
+ end
+ if statusVisited[t] then
+ error("circular reference")
+ end
+ statusVisited[t] = true
+ if type(first_val) == 'string' then
+ local key = {}
+ for k in next, t do
+ if type(k) ~= "string" then
+ error("invalid table: mixed or invalid key types")
+ end
+ key[#key+1] = k
+ end
+ table_sort(key)
+ statusBuilder[#statusBuilder+1] = "{"
+ statusDep = statusDep + 1
+ encode_newline()
+ local k = key[1]
+ statusBuilder[#statusBuilder+1] = '"'
+ statusBuilder[#statusBuilder+1] = encode_string(k)
+ statusBuilder[#statusBuilder+1] = '": '
+ encode(t[k])
+ for i = 2, #key do
+ local k = key[i]
+ statusBuilder[#statusBuilder+1] = ","
+ encode_newline()
+ statusBuilder[#statusBuilder+1] = '"'
+ statusBuilder[#statusBuilder+1] = encode_string(k)
+ statusBuilder[#statusBuilder+1] = '": '
+ encode(t[k])
+ end
+ statusDep = statusDep - 1
+ encode_newline()
+ statusVisited[t] = nil
+ return "}"
+ elseif json.supportSparseArray then
+ local max = 0
+ for k in next, t do
+ if math_type(k) ~= "integer" or k <= 0 then
+ error("invalid table: mixed or invalid key types")
+ end
+ if max < k then
+ max = k
+ end
+ end
+ statusBuilder[#statusBuilder+1] = "["
+ statusDep = statusDep + 1
+ encode_newline()
+ encode(t[1])
+ for i = 2, max do
+ statusBuilder[#statusBuilder+1] = ","
+ encode_newline()
+ encode(t[i])
+ end
+ statusDep = statusDep - 1
+ encode_newline()
+ statusVisited[t] = nil
+ return "]"
+ else
+ if t[1] == nil then
+ error("invalid table: mixed or invalid key types")
+ end
+ statusBuilder[#statusBuilder+1] = "["
+ statusDep = statusDep + 1
+ encode_newline()
+ encode(t[1])
+ local count = 2
+ while t[count] ~= nil do
+ statusBuilder[#statusBuilder+1] = ","
+ encode_newline()
+ encode(t[count])
+ count = count + 1
+ end
+ if next(t, count-1) ~= nil then
+ error("invalid table: mixed or invalid key types")
+ end
+ statusDep = statusDep - 1
+ encode_newline()
+ statusVisited[t] = nil
+ return "]"
+ end
+end
+
+function json.encode(v, option)
+ statusVisited = {}
+ statusBuilder = {}
+ statusDep = 0
+ statusOpt = option and setmetatable(option, defaultOpt) or defaultOpt
+ encode(v)
+ return table_concat(statusBuilder)
+end
+
+-- json.decode --
+
+local statusBuf
+local statusPos
+local statusTop
+local statusAry = {}
+local statusRef = {}
+
+local function find_line()
+ local line = 1
+ local pos = 1
+ while true do
+ local f, _, nl1, nl2 = string_find(statusBuf, '([\n\r])([\n\r]?)', pos)
+ if not f then
+ return line, statusPos - pos + 1
+ end
+ local newpos = f + ((nl1 == nl2 or nl2 == '') and 1 or 2)
+ if newpos > statusPos then
+ return line, statusPos - pos + 1
+ end
+ pos = newpos
+ line = line + 1
+ end
+end
+
+local function decode_error(msg)
+ error(string_format("ERROR: %s at line %d col %d", msg, find_line()), 2)
+end
+
+local function get_word()
+ return string_match(statusBuf, "^[^ \t\r\n%]},]*", statusPos)
+end
+
+local function skip_comment(b)
+ if b ~= 47 --[[ '/' ]] then
+ return
+ end
+ local c = string_byte(statusBuf, statusPos+1)
+ if c == 42 --[[ '*' ]] then
+ -- block comment
+ local pos = string_find(statusBuf, "*/", statusPos)
+ if pos then
+ statusPos = pos + 2
+ else
+ statusPos = #statusBuf + 1
+ end
+ return true
+ elseif c == 47 --[[ '/' ]] then
+ -- line comment
+ local pos = string_find(statusBuf, "[\r\n]", statusPos)
+ if pos then
+ statusPos = pos
+ else
+ statusPos = #statusBuf + 1
+ end
+ return true
+ end
+end
+
+local function next_byte()
+ local pos = string_find(statusBuf, "[^ \t\r\n]", statusPos)
+ if pos then
+ statusPos = pos
+ local b = string_byte(statusBuf, pos)
+ if not skip_comment(b) then
+ return b
+ end
+ return next_byte()
+ end
+ return -1
+end
+
+local function decode_unicode_surrogate(s1, s2)
+ return utf8_char(0x10000 + (tonumber(s1, 16) - 0xd800) * 0x400 + (tonumber(s2, 16) - 0xdc00))
+end
+
+local function decode_unicode_escape(s)
+ return utf8_char(tonumber(s, 16))
+end
+
+local function decode_string()
+ local has_unicode_escape = false
+ local has_escape = false
+ local i = statusPos + 1
+ while true do
+ i = string_find(statusBuf, '[%z\1-\31\\"]', i)
+ if not i then
+ decode_error "expected closing quote for string"
+ end
+ local x = string_byte(statusBuf, i)
+ if x < 32 then
+ statusPos = i
+ decode_error "control character in string"
+ end
+ if x == 34 --[[ '"' ]] then
+ local s = string_sub(statusBuf, statusPos + 1, i - 1)
+ if has_unicode_escape then
+ s = string_gsub(string_gsub(s
+ , "\\u([dD][89aAbB]%x%x)\\u([dD][c-fC-F]%x%x)", decode_unicode_surrogate)
+ , "\\u(%x%x%x%x)", decode_unicode_escape)
+ end
+ if has_escape then
+ s = string_gsub(s, "\\.", decode_escape_map)
+ end
+ statusPos = i + 1
+ return s
+ end
+ --assert(x == 92 --[[ "\\" ]])
+ local nx = string_byte(statusBuf, i+1)
+ if nx == 117 --[[ "u" ]] then
+ if not string_match(statusBuf, "^%x%x%x%x", i+2) then
+ statusPos = i
+ decode_error "invalid unicode escape in string"
+ end
+ has_unicode_escape = true
+ i = i + 6
+ else
+ if not decode_escape_set[nx] then
+ statusPos = i
+ decode_error("invalid escape char '" .. (nx and string_char(nx) or "<eol>") .. "' in string")
+ end
+ has_escape = true
+ i = i + 2
+ end
+ end
+end
+
+local function decode_number()
+ local num, c = string_match(statusBuf, '^([0-9]+%.?[0-9]*)([eE]?)', statusPos)
+ if not num or string_byte(num, -1) == 0x2E --[[ "." ]] then
+ decode_error("invalid number '" .. get_word() .. "'")
+ end
+ if c ~= '' then
+ num = string_match(statusBuf, '^([^eE]*[eE][-+]?[0-9]+)[ \t\r\n%]},/]', statusPos)
+ if not num then
+ decode_error("invalid number '" .. get_word() .. "'")
+ end
+ end
+ statusPos = statusPos + #num
+ return tonumber(num)
+end
+
+local function decode_number_zero()
+ local num, c = string_match(statusBuf, '^(.%.?[0-9]*)([eE]?)', statusPos)
+ if not num or string_byte(num, -1) == 0x2E --[[ "." ]] or string_match(statusBuf, '^.[0-9]+', statusPos) then
+ decode_error("invalid number '" .. get_word() .. "'")
+ end
+ if c ~= '' then
+ num = string_match(statusBuf, '^([^eE]*[eE][-+]?[0-9]+)[ \t\r\n%]},/]', statusPos)
+ if not num then
+ decode_error("invalid number '" .. get_word() .. "'")
+ end
+ end
+ statusPos = statusPos + #num
+ return tonumber(num)
+end
+
+local function decode_number_negative()
+ statusPos = statusPos + 1
+ local c = string_byte(statusBuf, statusPos)
+ if c then
+ if c == 0x30 then
+ return -decode_number_zero()
+ elseif c > 0x30 and c < 0x3A then
+ return -decode_number()
+ end
+ end
+ decode_error("invalid number '" .. get_word() .. "'")
+end
+
+local function decode_true()
+ if string_sub(statusBuf, statusPos, statusPos+3) ~= "true" then
+ decode_error("invalid literal '" .. get_word() .. "'")
+ end
+ statusPos = statusPos + 4
+ return true
+end
+
+local function decode_false()
+ if string_sub(statusBuf, statusPos, statusPos+4) ~= "false" then
+ decode_error("invalid literal '" .. get_word() .. "'")
+ end
+ statusPos = statusPos + 5
+ return false
+end
+
+local function decode_null()
+ if string_sub(statusBuf, statusPos, statusPos+3) ~= "null" then
+ decode_error("invalid literal '" .. get_word() .. "'")
+ end
+ statusPos = statusPos + 4
+ return json.null
+end
+
+local function decode_array()
+ statusPos = statusPos + 1
+ local res = {}
+ local chr = next_byte()
+ if chr == 93 --[[ ']' ]] then
+ statusPos = statusPos + 1
+ return res
+ end
+ statusTop = statusTop + 1
+ statusAry[statusTop] = true
+ statusRef[statusTop] = res
+ return res
+end
+
+local function decode_object()
+ statusPos = statusPos + 1
+ local res = {}
+ local chr = next_byte()
+ if chr == 125 --[[ ']' ]] then
+ statusPos = statusPos + 1
+ return json.createEmptyObject()
+ end
+ statusTop = statusTop + 1
+ statusAry[statusTop] = false
+ statusRef[statusTop] = res
+ return res
+end
+
+local decode_uncompleted_map = {
+ [ string_byte '"' ] = decode_string,
+ [ string_byte "0" ] = decode_number_zero,
+ [ string_byte "1" ] = decode_number,
+ [ string_byte "2" ] = decode_number,
+ [ string_byte "3" ] = decode_number,
+ [ string_byte "4" ] = decode_number,
+ [ string_byte "5" ] = decode_number,
+ [ string_byte "6" ] = decode_number,
+ [ string_byte "7" ] = decode_number,
+ [ string_byte "8" ] = decode_number,
+ [ string_byte "9" ] = decode_number,
+ [ string_byte "-" ] = decode_number_negative,
+ [ string_byte "t" ] = decode_true,
+ [ string_byte "f" ] = decode_false,
+ [ string_byte "n" ] = decode_null,
+ [ string_byte "[" ] = decode_array,
+ [ string_byte "{" ] = decode_object,
+}
+local function unexpected_character()
+ decode_error("unexpected character '" .. string_sub(statusBuf, statusPos, statusPos) .. "'")
+end
+local function unexpected_eol()
+ decode_error("unexpected character '<eol>'")
+end
+
+local decode_map = {}
+for i = 0, 255 do
+ decode_map[i] = decode_uncompleted_map[i] or unexpected_character
+end
+decode_map[-1] = unexpected_eol
+
+local function decode()
+ return decode_map[next_byte()]()
+end
+
+local function decode_item()
+ local top = statusTop
+ local ref = statusRef[top]
+ if statusAry[top] then
+ ref[#ref+1] = decode()
+ else
+ local key = decode_string()
+ if next_byte() ~= 58 --[[ ':' ]] then
+ decode_error "expected ':'"
+ end
+ statusPos = statusPos + 1
+ ref[key] = decode()
+ end
+ if top == statusTop then
+ repeat
+ local chr = next_byte(); statusPos = statusPos + 1
+ if chr == 44 --[[ "," ]] then
+ local c = next_byte()
+ if statusAry[statusTop] then
+ if c ~= 93 --[[ "]" ]] then return end
+ else
+ if c ~= 125 --[[ "}" ]] then return end
+ end
+ statusPos = statusPos + 1
+ else
+ if statusAry[statusTop] then
+ if chr ~= 93 --[[ "]" ]] then decode_error "expected ']' or ','" end
+ else
+ if chr ~= 125 --[[ "}" ]] then decode_error "expected '}' or ','" end
+ end
+ end
+ statusTop = statusTop - 1
+ until statusTop == 0
+ end
+end
+
+function json.decode(str)
+ if type(str) ~= "string" then
+ error("expected argument of type string, got " .. type(str))
+ end
+ statusBuf = str
+ statusPos = 1
+ statusTop = 0
+ if next_byte() == -1 then
+ return json.null
+ end
+ local res = decode()
+ while statusTop > 0 do
+ decode_item()
+ end
+ if string_find(statusBuf, "[^ \t\r\n]", statusPos) then
+ decode_error "trailing garbage"
+ end
+ return res
+end
+
+return json
diff --git a/script/language.lua b/script/language.lua
index 771dc948..22546fb8 100644
--- a/script/language.lua
+++ b/script/language.lua
@@ -6,7 +6,9 @@ local function supportLanguage()
local list = {}
for path in fs.pairs(ROOT / 'locale') do
if fs.is_directory(path) then
- list[#list+1] = path:filename():string():lower()
+ local id = path:filename():string():lower()
+ list[#list+1] = id
+ list[id] = true
end
end
return list
diff --git a/script/log.lua b/script/log.lua
index 597bdc4e..6cb865c3 100644
--- a/script/log.lua
+++ b/script/log.lua
@@ -85,7 +85,10 @@ function m.warn(...)
end
function m.error(...)
- return pushLog('error', ...)
+ -- Don't use tail calls,
+ -- Otherwise, the count of `debug.getinfo` will be wrong
+ local msg = pushLog('error', ...)
+ return msg
end
function m.raw(thd, level, msg, source, currentline, clock)
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index 0ece65fc..06169b09 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -16,6 +16,7 @@ local type = type
---@field uri uri
---@field start integer
---@field finish integer
+---@field range integer
---@field effect integer
---@field attrs string[]
---@field specials parser.object[]
@@ -56,6 +57,14 @@ local type = type
---@field init parser.object
---@field step parser.object
---@field redundant { max: integer, passed: integer }
+---@field filter parser.object
+---@field loc parser.object
+---@field keyword integer[]
+---@field casts parser.object[]
+---@field mode? '+' | '-'
+---@field hasGoTo? true
+---@field hasReturn? true
+---@field hasBreak? true
---@field _root parser.object
---@class guide
@@ -71,6 +80,7 @@ local blockTypes = {
['repeat'] = true,
['do'] = true,
['function'] = true,
+ ['if'] = true,
['ifblock'] = true,
['elseblock'] = true,
['elseifblock'] = true,
@@ -141,6 +151,9 @@ local childMap = {
['doc.see'] = {'name', 'field'},
['doc.version'] = {'#versions'},
['doc.diagnostic'] = {'#names'},
+ ['doc.as'] = {'as'},
+ ['doc.cast'] = {'loc', '#casts'},
+ ['doc.cast.block'] = {'extends'},
}
---@type table<string, fun(obj: parser.object, list: parser.object[])>
@@ -393,6 +406,7 @@ function m.getRoot(obj)
end
local parent = obj.parent
if not parent then
+ log.error('Can not find out root:', obj.type)
return nil
end
obj = parent
@@ -413,6 +427,7 @@ function m.getUri(obj)
return ''
end
+---@return parser.object?
function m.getENV(source, start)
if not start then
start = 1
@@ -446,19 +461,17 @@ function m.getFunctionVarArgs(func)
end
--- 获取指定区块中可见的局部变量
----@param block table
----@param name string {comment = '变量名'}
----@param pos integer {comment = '可见位置'}
-function m.getLocal(block, name, pos)
- block = m.getBlock(block)
- for _ = 1, 10000 do
- if not block then
- return nil
- end
- local locals = block.locals
- local res
+---@param source parser.object
+---@param name string # 变量名
+---@param pos integer # 可见位置
+---@return parser.object?
+function m.getLocal(source, name, pos)
+ local root = m.getRoot(source)
+ local res
+ m.eachSourceContain(root, pos, function (src)
+ local locals = src.locals
if not locals then
- goto CONTINUE
+ return
end
for i = 1, #locals do
local loc = locals[i]
@@ -471,13 +484,8 @@ function m.getLocal(block, name, pos)
end
end
end
- if res then
- return res, res
- end
- ::CONTINUE::
- block = m.getParentBlock(block)
- end
- error('guide.getLocal overstack')
+ end)
+ return res
end
--- 获取指定区块中所有的可见局部变量名称
@@ -602,6 +610,9 @@ local function addChilds(list, obj)
end
--- 遍历所有包含position的source
+---@param ast parser.object
+---@param position integer
+---@param callback fun(src: parser.object)
function m.eachSourceContain(ast, position, callback)
local list = { ast }
local mark = {}
@@ -922,6 +933,7 @@ function m.getKeyNameOfLiteral(obj)
end
end
+---@return string?
function m.getKeyName(obj)
if not obj then
return nil
@@ -1027,8 +1039,6 @@ function m.getKeyType(obj)
return type(obj.field[1])
elseif tp == 'doc.type.field' then
return type(obj.name[1])
- elseif tp == 'dummy' then
- return 'string'
end
if tp == 'doc.field.name' then
return type(obj[1])
diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua
index 5a2e1d09..d8e31950 100644
--- a/script/parser/luadoc.lua
+++ b/script/parser/luadoc.lua
@@ -2,10 +2,11 @@ local m = require 'lpeglabel'
local re = require 'parser.relabel'
local guide = require 'parser.guide'
local parser = require 'parser.newparser'
+local util = require 'utility'
local TokenTypes, TokenStarts, TokenFinishs, TokenContents, TokenMarks
local Ci, Offset, pushWarning, NextComment, Lines
-local parseType
+local parseType, parseTypeUnit
---@type any
local Parser = re.compile([[
Main <- (Token / Sp)*
@@ -52,6 +53,7 @@ Symbol <- ({} {
/ '...'
/ '['
/ ']'
+ / '-' !'-'
} {})
-> Symbol
]], {
@@ -124,6 +126,8 @@ Symbol <- ({} {
---@class parser.object
---@field literal boolean
---@field signs parser.object[]
+---@field originalComment parser.object
+---@field as? parser.object
local function trim(str)
return str:match '^%s*(%S+)%s*$'
@@ -336,104 +340,6 @@ local function parseSigns(parent)
return signs
end
-local function parseClass(parent)
- local result = {
- type = 'doc.class',
- parent = parent,
- fields = {},
- }
- result.class = parseName('doc.class.name', result)
- if not result.class then
- pushWarning {
- type = 'LUADOC_MISS_CLASS_NAME',
- start = getFinish(),
- finish = getFinish(),
- }
- return nil
- end
- result.start = getStart()
- result.finish = getFinish()
- result.signs = parseSigns(result)
- if not checkToken('symbol', ':', 1) then
- return result
- end
- nextToken()
-
- result.extends = {}
-
- while true do
- local extend = parseName('doc.extends.name', result)
- or parseTable(result)
- if not extend then
- pushWarning {
- type = 'LUADOC_MISS_CLASS_EXTENDS_NAME',
- start = getFinish(),
- finish = getFinish(),
- }
- return result
- end
- result.extends[#result.extends+1] = extend
- result.finish = getFinish()
- if not checkToken('symbol', ',', 1) then
- break
- end
- nextToken()
- end
- return result
-end
-
-local function parseTypeUnitArray(parent, node)
- if not checkToken('symbol', '[]', 1) then
- return nil
- end
- nextToken()
- local result = {
- type = 'doc.type.array',
- start = node.start,
- finish = getFinish(),
- node = node,
- parent = parent,
- }
- node.parent = result
- return result
-end
-
-local function parseTypeUnitSign(parent, node)
- if not checkToken('symbol', '<', 1) then
- return nil
- end
- nextToken()
- local result = {
- type = 'doc.type.sign',
- start = node.start,
- finish = getFinish(),
- node = node,
- parent = parent,
- signs = {},
- }
- node.parent = result
- while true do
- local sign = parseType(result)
- if not sign then
- pushWarning {
- type = 'LUA_DOC_MISS_SIGN',
- start = getFinish(),
- finish = getFinish(),
- }
- break
- end
- result.signs[#result.signs+1] = sign
- if checkToken('symbol', ',', 1) then
- nextToken()
- else
- break
- end
- end
- nextSymbolOrError '>'
- result.finish = getFinish()
- return result
-end
-
local function parseDots(tp, parent)
if not checkToken('symbol', '...', 1) then
return
@@ -527,8 +433,6 @@ local function parseTypeUnitFunction(parent)
return typeUnit
end
-local parseTypeUnit
-
local function parseFunction(parent)
local _, content = peekToken()
if content == 'async' then
@@ -551,6 +455,58 @@ local function parseFunction(parent)
end
end
+local function parseTypeUnitArray(parent, node)
+ if not checkToken('symbol', '[]', 1) then
+ return nil
+ end
+ nextToken()
+ local result = {
+ type = 'doc.type.array',
+ start = node.start,
+ finish = getFinish(),
+ node = node,
+ parent = parent,
+ }
+ node.parent = result
+ return result
+end
+
+local function parseTypeUnitSign(parent, node)
+ if not checkToken('symbol', '<', 1) then
+ return nil
+ end
+ nextToken()
+ local result = {
+ type = 'doc.type.sign',
+ start = node.start,
+ finish = getFinish(),
+ node = node,
+ parent = parent,
+ signs = {},
+ }
+ node.parent = result
+ while true do
+ local sign = parseType(result)
+ if not sign then
+ pushWarning {
+ type = 'LUA_DOC_MISS_SIGN',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ break
+ end
+ result.signs[#result.signs+1] = sign
+ if checkToken('symbol', ',', 1) then
+ nextToken()
+ else
+ break
+ end
+ end
+ nextSymbolOrError '>'
+ result.finish = getFinish()
+ return result
+end
+
local function parseString(parent)
local tp, content = peekToken()
if not tp or tp ~= 'string' then
@@ -709,6 +665,10 @@ function parseType(parent)
if not result.start then
result.start = getFinish()
end
+ if checkToken('symbol', '?', 1) then
+ nextToken()
+ result.optional = true
+ end
result.finish = getFinish()
result.firstFinish = result.finish
@@ -785,405 +745,534 @@ function parseType(parent)
return result
end
-local function parseAlias()
- local result = {
- type = 'doc.alias',
- }
- result.alias = parseName('doc.alias.name', result)
- if not result.alias then
- pushWarning {
- type = 'LUADOC_MISS_ALIAS_NAME',
- start = getFinish(),
- finish = getFinish(),
- }
- return nil
- end
- result.start = getStart()
- result.signs = parseSigns(result)
- result.extends = parseType(result)
- if not result.extends then
- pushWarning {
- type = 'LUADOC_MISS_ALIAS_EXTENDS',
- start = getFinish(),
- finish = getFinish(),
- }
- return nil
- end
- result.finish = getFinish()
- return result
-end
-
-local function parseParam()
- local result = {
- type = 'doc.param',
- }
- result.param = parseName('doc.param.name', result)
- or parseDots('doc.param.name', result)
- if not result.param then
- pushWarning {
- type = 'LUADOC_MISS_PARAM_NAME',
- start = getFinish(),
- finish = getFinish(),
+local docSwitch = util.switch()
+ : case 'class'
+ : call(function ()
+ local result = {
+ type = 'doc.class',
+ fields = {},
}
- return nil
- end
- if checkToken('symbol', '?', 1) then
- nextToken()
- result.optional = true
- end
- result.start = result.param.start
- result.finish = getFinish()
- result.extends = parseType(result)
- if not result.extends then
- pushWarning {
- type = 'LUADOC_MISS_PARAM_EXTENDS',
- start = getFinish(),
- finish = getFinish(),
- }
- return result
- end
- result.finish = getFinish()
- result.firstFinish = result.extends.firstFinish
- return result
-end
-
-local function parseReturn()
- local result = {
- type = 'doc.return',
- returns = {},
- }
- while true do
- local docType = parseType(result)
- if not docType then
- break
- end
- if not result.start then
- result.start = docType.start
- end
- if checkToken('symbol', '?', 1) then
- nextToken()
- docType.optional = true
+ result.class = parseName('doc.class.name', result)
+ if not result.class then
+ pushWarning {
+ type = 'LUADOC_MISS_CLASS_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
end
- docType.name = parseName('doc.return.name', docType)
- result.returns[#result.returns+1] = docType
- if not checkToken('symbol', ',', 1) then
- break
+ result.start = getStart()
+ result.finish = getFinish()
+ result.signs = parseSigns(result)
+ if not checkToken('symbol', ':', 1) then
+ return result
end
nextToken()
- end
- if #result.returns == 0 then
- return nil
- end
- result.finish = getFinish()
- return result
-end
-local function parseField()
- local result = {
- type = 'doc.field',
- }
- try(function ()
- local tp, value = nextToken()
- if tp == 'name' then
- if value == 'public'
- or value == 'protected'
- or value == 'private' then
- result.visible = value
- result.start = getStart()
- return true
+ result.extends = {}
+
+ while true do
+ local extend = parseName('doc.extends.name', result)
+ or parseTable(result)
+ if not extend then
+ pushWarning {
+ type = 'LUADOC_MISS_CLASS_EXTENDS_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return result
end
+ result.extends[#result.extends+1] = extend
+ result.finish = getFinish()
+ if not checkToken('symbol', ',', 1) then
+ break
+ end
+ nextToken()
end
- return false
+ return result
end)
- result.field = parseName('doc.field.name', result)
- or parseIndexField('doc.field.name', result)
- if not result.field then
- pushWarning {
- type = 'LUADOC_MISS_FIELD_NAME',
- start = getFinish(),
- finish = getFinish(),
- }
- return nil
- end
- if not result.start then
- result.start = result.field.start
- end
- if checkToken('symbol', '?', 1) then
- nextToken()
- result.optional = true
- end
- result.extends = parseType(result)
- if not result.extends then
- pushWarning {
- type = 'LUADOC_MISS_FIELD_EXTENDS',
- start = getFinish(),
- finish = getFinish(),
- }
- return nil
- end
- result.finish = getFinish()
- return result
-end
-
-local function parseGeneric()
- local result = {
- type = 'doc.generic',
- generics = {},
- }
- while true do
- local object = {
- type = 'doc.generic.object',
- parent = result,
+ : case 'type'
+ : call(function ()
+ return parseType()
+ end)
+ : case 'alias'
+ : call(function ()
+ local result = {
+ type = 'doc.alias',
}
- object.generic = parseName('doc.generic.name', object)
- if not object.generic then
+ result.alias = parseName('doc.alias.name', result)
+ if not result.alias then
pushWarning {
- type = 'LUADOC_MISS_GENERIC_NAME',
+ type = 'LUADOC_MISS_ALIAS_NAME',
start = getFinish(),
finish = getFinish(),
}
return nil
end
- object.start = object.generic.start
- if not result.start then
- result.start = object.start
+ result.start = getStart()
+ result.signs = parseSigns(result)
+ result.extends = parseType(result)
+ if not result.extends then
+ pushWarning {
+ type = 'LUADOC_MISS_ALIAS_EXTENDS',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
end
- if checkToken('symbol', ':', 1) then
+ result.finish = getFinish()
+ return result
+ end)
+ : case 'param'
+ : call(function ()
+ local result = {
+ type = 'doc.param',
+ }
+ result.param = parseName('doc.param.name', result)
+ or parseDots('doc.param.name', result)
+ if not result.param then
+ pushWarning {
+ type = 'LUADOC_MISS_PARAM_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ if checkToken('symbol', '?', 1) then
nextToken()
- object.extends = parseType(object)
+ result.optional = true
end
- object.finish = getFinish()
- result.generics[#result.generics+1] = object
- if not checkToken('symbol', ',', 1) then
- break
+ result.start = result.param.start
+ result.finish = getFinish()
+ result.extends = parseType(result)
+ if not result.extends then
+ pushWarning {
+ type = 'LUADOC_MISS_PARAM_EXTENDS',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return result
end
- nextToken()
- end
- result.finish = getFinish()
- return result
-end
-
-local function parseVararg()
- local result = {
- type = 'doc.vararg',
- }
- result.vararg = parseType(result)
- if not result.vararg then
- pushWarning {
- type = 'LUADOC_MISS_VARARG_TYPE',
- start = getFinish(),
- finish = getFinish(),
+ result.finish = getFinish()
+ result.firstFinish = result.extends.firstFinish
+ return result
+ end)
+ : case 'return'
+ : call(function ()
+ local result = {
+ type = 'doc.return',
+ returns = {},
}
- return
- end
- result.start = result.vararg.start
- result.finish = result.vararg.finish
- return result
-end
-
-local function parseOverload()
- local tp, name = peekToken()
- if tp ~= 'name'
- or (name ~= 'fun' and name ~= 'async') then
- pushWarning {
- type = 'LUADOC_MISS_FUN_AFTER_OVERLOAD',
- start = getFinish(),
- finish = getFinish(),
+ while true do
+ local docType = parseType(result)
+ if not docType then
+ break
+ end
+ if not result.start then
+ result.start = docType.start
+ end
+ if checkToken('symbol', '?', 1) then
+ nextToken()
+ docType.optional = true
+ end
+ docType.name = parseName('doc.return.name', docType)
+ result.returns[#result.returns+1] = docType
+ if not checkToken('symbol', ',', 1) then
+ break
+ end
+ nextToken()
+ end
+ if #result.returns == 0 then
+ return nil
+ end
+ result.finish = getFinish()
+ return result
+ end)
+ : case 'field'
+ : call(function ()
+ local result = {
+ type = 'doc.field',
}
- return nil
- end
- local result = {
- type = 'doc.overload',
- }
- result.overload = parseFunction(result)
- if not result.overload then
- return nil
- end
- result.overload.parent = result
- result.start = result.overload.start
- result.finish = result.overload.finish
- return result
-end
-
-local function parseDeprecated()
- return {
- type = 'doc.deprecated',
- start = getFinish(),
- finish = getFinish(),
- }
-end
-
-local function parseMeta()
- return {
- type = 'doc.meta',
- start = getFinish(),
- finish = getFinish(),
- }
-end
-
-local function parseVersion()
- local result = {
- type = 'doc.version',
- versions = {},
- }
- while true do
- local tp, text = nextToken()
- if not tp then
+ try(function ()
+ local tp, value = nextToken()
+ if tp == 'name' then
+ if value == 'public'
+ or value == 'protected'
+ or value == 'private' then
+ result.visible = value
+ result.start = getStart()
+ return true
+ end
+ end
+ return false
+ end)
+ result.field = parseName('doc.field.name', result)
+ or parseIndexField('doc.field.name', result)
+ if not result.field then
pushWarning {
- type = 'LUADOC_MISS_VERSION',
+ type = 'LUADOC_MISS_FIELD_NAME',
start = getFinish(),
finish = getFinish(),
}
- break
+ return nil
end
if not result.start then
- result.start = getStart()
+ result.start = result.field.start
end
- local version = {
- type = 'doc.version.unit',
- parent = result,
- start = getStart(),
+ if checkToken('symbol', '?', 1) then
+ nextToken()
+ result.optional = true
+ end
+ result.extends = parseType(result)
+ if not result.extends then
+ pushWarning {
+ type = 'LUADOC_MISS_FIELD_EXTENDS',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ result.finish = getFinish()
+ return result
+ end)
+ : case 'generic'
+ : call(function ()
+ local result = {
+ type = 'doc.generic',
+ generics = {},
}
- if tp == 'symbol' then
- if text == '>' then
- version.ge = true
- elseif text == '<' then
- version.le = true
+ while true do
+ local object = {
+ type = 'doc.generic.object',
+ parent = result,
+ }
+ object.generic = parseName('doc.generic.name', object)
+ if not object.generic then
+ pushWarning {
+ type = 'LUADOC_MISS_GENERIC_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ object.start = object.generic.start
+ if not result.start then
+ result.start = object.start
end
- tp, text = nextToken()
+ if checkToken('symbol', ':', 1) then
+ nextToken()
+ object.extends = parseType(object)
+ end
+ object.finish = getFinish()
+ result.generics[#result.generics+1] = object
+ if not checkToken('symbol', ',', 1) then
+ break
+ end
+ nextToken()
end
- if tp ~= 'name' then
+ result.finish = getFinish()
+ return result
+ end)
+ : case 'vararg'
+ : call(function ()
+ local result = {
+ type = 'doc.vararg',
+ }
+ result.vararg = parseType(result)
+ if not result.vararg then
pushWarning {
- type = 'LUADOC_MISS_VERSION',
- start = getStart(),
+ type = 'LUADOC_MISS_VARARG_TYPE',
+ start = getFinish(),
finish = getFinish(),
}
- break
+ return
end
- version.version = tonumber(text) or text
- version.finish = getFinish()
- result.versions[#result.versions+1] = version
- if not checkToken('symbol', ',', 1) then
- break
+ result.start = result.vararg.start
+ result.finish = result.vararg.finish
+ return result
+ end)
+ : case 'overload'
+ : call(function ()
+ local tp, name = peekToken()
+ if tp ~= 'name'
+ or (name ~= 'fun' and name ~= 'async') then
+ pushWarning {
+ type = 'LUADOC_MISS_FUN_AFTER_OVERLOAD',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
end
- nextToken()
- end
- if #result.versions == 0 then
- return nil
- end
- result.finish = getFinish()
- return result
-end
-
-local function parseSee()
- local result = {
- type = 'doc.see',
- }
- result.name = parseName('doc.see.name', result)
- if not result.name then
- return nil
- end
- result.start = result.name.start
- result.finish = result.name.finish
- if checkToken('symbol', '#', 1) then
- nextToken()
- result.field = parseName('doc.see.field', result)
- result.finish = getFinish()
- end
- return result
-end
-
-local function parseDiagnostic()
- local result = {
- type = 'doc.diagnostic',
- }
- local nextTP, mode = nextToken()
- if nextTP ~= 'name' then
- pushWarning {
- type = 'LUADOC_MISS_DIAG_MODE',
+ local result = {
+ type = 'doc.overload',
+ }
+ result.overload = parseFunction(result)
+ if not result.overload then
+ return nil
+ end
+ result.overload.parent = result
+ result.start = result.overload.start
+ result.finish = result.overload.finish
+ return result
+ end)
+ : case 'deprecated'
+ : call(function ()
+ return {
+ type = 'doc.deprecated',
start = getFinish(),
finish = getFinish(),
}
- return nil
- end
- result.mode = mode
- result.start = getStart()
- result.finish = getFinish()
- if mode ~= 'disable-next-line'
- and mode ~= 'disable-line'
- and mode ~= 'disable'
- and mode ~= 'enable' then
- pushWarning {
- type = 'LUADOC_ERROR_DIAG_MODE',
- start = result.start,
- finish = result.finish,
+ end)
+ : case 'meta'
+ : call(function ()
+ return {
+ type = 'doc.meta',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ end)
+ : case 'version'
+ : call(function ()
+ local result = {
+ type = 'doc.version',
+ versions = {},
}
- end
-
- if checkToken('symbol', ':', 1) then
- nextToken()
- result.names = {}
while true do
- local name = parseName('doc.diagnostic.name', result)
- if not name then
+ local tp, text = nextToken()
+ if not tp then
pushWarning {
- type = 'LUADOC_MISS_DIAG_NAME',
+ type = 'LUADOC_MISS_VERSION',
start = getFinish(),
finish = getFinish(),
}
- return result
+ break
+ end
+ if not result.start then
+ result.start = getStart()
end
- result.names[#result.names+1] = name
+ local version = {
+ type = 'doc.version.unit',
+ parent = result,
+ start = getStart(),
+ }
+ if tp == 'symbol' then
+ if text == '>' then
+ version.ge = true
+ elseif text == '<' then
+ version.le = true
+ end
+ tp, text = nextToken()
+ end
+ if tp ~= 'name' then
+ pushWarning {
+ type = 'LUADOC_MISS_VERSION',
+ start = getStart(),
+ finish = getFinish(),
+ }
+ break
+ end
+ version.version = tonumber(text) or text
+ version.finish = getFinish()
+ result.versions[#result.versions+1] = version
if not checkToken('symbol', ',', 1) then
break
end
nextToken()
end
- end
+ if #result.versions == 0 then
+ return nil
+ end
+ result.finish = getFinish()
+ return result
+ end)
+ : case 'see'
+ : call(function ()
+ local result = {
+ type = 'doc.see',
+ }
+ result.name = parseName('doc.see.name', result)
+ if not result.name then
+ return nil
+ end
+ result.start = result.name.start
+ result.finish = result.name.finish
+ if checkToken('symbol', '#', 1) then
+ nextToken()
+ result.field = parseName('doc.see.field', result)
+ result.finish = getFinish()
+ end
+ return result
+ end)
+ : case 'diagnostic'
+ : call(function ()
+ local result = {
+ type = 'doc.diagnostic',
+ }
+ local nextTP, mode = nextToken()
+ if nextTP ~= 'name' then
+ pushWarning {
+ type = 'LUADOC_MISS_DIAG_MODE',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ result.mode = mode
+ result.start = getStart()
+ result.finish = getFinish()
+ if mode ~= 'disable-next-line'
+ and mode ~= 'disable-line'
+ and mode ~= 'disable'
+ and mode ~= 'enable' then
+ pushWarning {
+ type = 'LUADOC_ERROR_DIAG_MODE',
+ start = result.start,
+ finish = result.finish,
+ }
+ end
- result.finish = getFinish()
+ if checkToken('symbol', ':', 1) then
+ nextToken()
+ result.names = {}
+ while true do
+ local name = parseName('doc.diagnostic.name', result)
+ if not name then
+ pushWarning {
+ type = 'LUADOC_MISS_DIAG_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return result
+ end
+ result.names[#result.names+1] = name
+ if not checkToken('symbol', ',', 1) then
+ break
+ end
+ nextToken()
+ end
+ end
- return result
-end
+ result.finish = getFinish()
-local function parseModule()
- local result = {
- type = 'doc.module',
- start = getFinish(),
- finish = getFinish(),
- }
- local tp, content = peekToken()
- if tp == 'string' then
- result.module = content
- nextToken()
- result.start = getStart()
+ return result
+ end)
+ : case 'module'
+ : call(function ()
+ local result = {
+ type = 'doc.module',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ local tp, content = peekToken()
+ if tp == 'string' then
+ result.module = content
+ nextToken()
+ result.start = getStart()
+ result.finish = getFinish()
+ result.smark = getMark()
+ else
+ pushWarning {
+ type = 'LUADOC_MISS_MODULE_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ end
+ return result
+ end)
+ : case 'async'
+ : call(function ()
+ return {
+ type = 'doc.async',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ end)
+ : case 'nodiscard'
+ : call(function ()
+ return {
+ type = 'doc.nodiscard',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ end)
+ : case 'as'
+ : call(function ()
+ local result = {
+ type = 'doc.as',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ result.as = parseType(result)
result.finish = getFinish()
- result.smark = getMark()
- else
- pushWarning {
- type = 'LUADOC_MISS_MODULE_NAME',
+ return result
+ end)
+ : case 'cast'
+ : call(function ()
+ local result = {
+ type = 'doc.cast',
start = getFinish(),
finish = getFinish(),
+ casts = {},
}
- end
- return result
-end
-local function parseAsync()
- return {
- type = 'doc.async',
- start = getFinish(),
- finish = getFinish(),
- }
-end
+ local loc = parseName('doc.cast.name', result)
+ if not loc then
+ pushWarning {
+ type = 'LUADOC_MISS_LOCAL_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return result
+ end
-local function parseNoDiscard()
- return {
- type = 'doc.nodiscard',
- start = getFinish(),
- finish = getFinish(),
- }
-end
+ result.loc = loc
+ result.finish = loc.finish
+
+ while true do
+ local block = {
+ type = 'doc.cast.block',
+ parent = result,
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ if checkToken('symbol', '+', 1) then
+ block.mode = '+'
+ nextToken()
+ block.start = getStart()
+ block.finish = getFinish()
+ elseif checkToken('symbol', '-', 1) then
+ block.mode = '-'
+ nextToken()
+ block.start = getStart()
+ block.finish = getFinish()
+ end
+
+ if checkToken('symbol', '?', 1) then
+ block.optional = true
+ nextToken()
+ block.start = block.start or getStart()
+ block.finish = block.finish
+ else
+ block.extends = parseType(block)
+ if block.extends then
+ block.start = block.start or block.extends.start
+ block.finish = block.extends.finish
+ end
+ end
+
+ if block.optional or block.extends then
+ result.casts[#result.casts+1] = block
+ end
+
+ if checkToken('symbol', ',', 1) then
+ nextToken()
+ else
+ break
+ end
+ end
+
+ return result
+ end)
local function convertTokens()
local tp, text = nextToken()
@@ -1198,41 +1287,7 @@ local function convertTokens()
}
return nil
end
- if text == 'class' then
- return parseClass()
- elseif text == 'type' then
- return parseType()
- elseif text == 'alias' then
- return parseAlias()
- elseif text == 'param' then
- return parseParam()
- elseif text == 'return' then
- return parseReturn()
- elseif text == 'field' then
- return parseField()
- elseif text == 'generic' then
- return parseGeneric()
- elseif text == 'vararg' then
- return parseVararg()
- elseif text == 'overload' then
- return parseOverload()
- elseif text == 'deprecated' then
- return parseDeprecated()
- elseif text == 'meta' then
- return parseMeta()
- elseif text == 'version' then
- return parseVersion()
- elseif text == 'see' then
- return parseSee()
- elseif text == 'diagnostic' then
- return parseDiagnostic()
- elseif text == 'module' then
- return parseModule()
- elseif text == 'async' then
- return parseAsync()
- elseif text == 'nodiscard' then
- return parseNoDiscard()
- end
+ return docSwitch(text)
end
local function trimTailComment(text)
@@ -1257,7 +1312,8 @@ end
local function buildLuaDoc(comment)
local text = comment.text
- local _, startPos = text:find('^%-%s*@')
+ local startPos = (comment.type == 'comment.short' and text:match '^%-%s*@()')
+ or (comment.type == 'comment.long' and text:match '^@()')
if not startPos then
return {
type = 'doc.comment',
@@ -1268,9 +1324,9 @@ local function buildLuaDoc(comment)
}
end
- local doc = text:sub(startPos + 1)
+ local doc = text:sub(startPos)
- parseTokens(doc, comment.start + startPos + 1)
+ parseTokens(doc, comment.start + startPos)
local result = convertTokens()
if result then
result.range = comment.finish
@@ -1313,16 +1369,21 @@ local function isNextLine(binded, doc)
return false
end
local lastDoc = binded[#binded]
- if lastDoc.type == 'doc.type' then
+ if lastDoc.type == 'doc.type'
+ or lastDoc.type == 'doc.module' then
return false
end
if lastDoc.type == 'doc.class'
or lastDoc.type == 'doc.field' then
if doc.type ~= 'doc.field'
- and doc.type ~= 'doc.comment' then
+ and doc.type ~= 'doc.comment'
+ and doc.type ~= 'doc.overload' then
return false
end
end
+ if doc.type == 'doc.cast' then
+ return false
+ end
local lastRow = guide.rowColOf(lastDoc.finish)
local newRow = guide.rowColOf(doc.start)
return newRow - lastRow == 1
@@ -1400,11 +1461,13 @@ local function bindDocsBetween(sources, binded, bindSources, start, finish)
if src.start >= start then
if src.type == 'local'
or src.type == 'self'
+ or src.type == 'setlocal'
or src.type == 'setglobal'
or src.type == 'tablefield'
or src.type == 'tableindex'
or src.type == 'setfield'
or src.type == 'setindex'
+ or src.type == 'setmethod'
or src.type == 'function' then
src.bindDocs = binded
bindSources[#bindSources+1] = src
diff --git a/script/parser/newparser.lua b/script/parser/newparser.lua
index e226417f..630c12c2 100644
--- a/script/parser/newparser.lua
+++ b/script/parser/newparser.lua
@@ -117,6 +117,7 @@ local Specials = {
['xpcall'] = true,
['pairs'] = true,
['ipairs'] = true,
+ ['assert'] = true,
}
local UnarySymbol = {
@@ -537,6 +538,7 @@ local function skipComment(isAction)
if longComment then
longComment.type = 'comment.long'
longComment.text = longComment[1]
+ longComment.mark = longComment[2]
longComment[1] = nil
longComment[2] = nil
State.comms[#State.comms+1] = longComment
@@ -689,9 +691,6 @@ local function parseLocalAttrs()
end
local function createLocal(obj, attrs)
- if not obj then
- return nil
- end
obj.type = 'local'
obj.effect = obj.finish
@@ -2891,7 +2890,11 @@ local function parseLocal()
pushActionIntoCurrentChunk(loc)
skipSpace()
parseMultiVars(loc, parseName, true)
- loc.effect = lastRightPosition()
+ if loc.value then
+ loc.effect = loc.value.finish
+ else
+ loc.effect = loc.finish
+ end
return loc
end
@@ -2946,13 +2949,22 @@ local function parseReturn()
end
pushActionIntoCurrentChunk(rtn)
for i = #Chunk, 1, -1 do
- local func = Chunk[i]
- if func.type == 'function'
- or func.type == 'main' then
- if not func.returns then
- func.returns = {}
+ local block = Chunk[i]
+ if block.type == 'function'
+ or block.type == 'main' then
+ if not block.returns then
+ block.returns = {}
end
- func.returns[#func.returns+1] = rtn
+ block.returns[#block.returns+1] = rtn
+ break
+ end
+ end
+ for i = #Chunk, 1, -1 do
+ local block = Chunk[i]
+ if block.type == 'ifblock'
+ or block.type == 'elseifblock'
+ or block.type == 'else' then
+ block.hasReturn = true
break
end
end
@@ -3052,6 +3064,15 @@ local function parseGoTo()
break
end
end
+ for i = #Chunk, 1, -1 do
+ local chunk = Chunk[i]
+ if chunk.type == 'ifblock'
+ or chunk.type == 'elseifblock'
+ or chunk.type == 'elseblock' then
+ chunk.hasGoTo = true
+ break
+ end
+ end
pushActionIntoCurrentChunk(action)
return action
@@ -3586,6 +3607,15 @@ local function parseBreak()
break
end
end
+ for i = #Chunk, 1, -1 do
+ local chunk = Chunk[i]
+ if chunk.type == 'ifblock'
+ or chunk.type == 'elseifblock'
+ or chunk.type == 'elseblock' then
+ chunk.hasBreak = true
+ break
+ end
+ end
if not ok and Mode == 'Lua' then
pushError {
type = 'BREAK_OUTSIDE',
diff --git a/script/proto/define.lua b/script/proto/define.lua
index 389cdf88..fb60c56c 100644
--- a/script/proto/define.lua
+++ b/script/proto/define.lua
@@ -9,10 +9,10 @@ m.DiagnosticSeverity = {
}
---@alias DiagnosticDefaultSeverity
----| '"Hint"'
----| '"Information"'
----| '"Warning"'
----| '"Error"'
+---| 'Hint'
+---| 'Information'
+---| 'Warning'
+---| 'Error'
--- 诊断类型与默认等级
---@type table<string, DiagnosticDefaultSeverity>
@@ -29,6 +29,7 @@ m.DiagnosticDefaultSeverity = {
['newline-call'] = 'Information',
['newfield-call'] = 'Warning',
['redundant-parameter'] = 'Warning',
+ ['missing-parameter'] = 'Warning',
['redundant-return'] = 'Warning',
['ambiguity-1'] = 'Warning',
['lowercase-global'] = 'Information',
@@ -47,6 +48,7 @@ m.DiagnosticDefaultSeverity = {
['await-in-sync'] = 'Warning',
['not-yieldable'] = 'Warning',
['discard-returns'] = 'Warning',
+ ['need-check-nil'] = 'Warning',
['type-check'] = 'Warning',
['duplicate-doc-alias'] = 'Warning',
@@ -63,9 +65,9 @@ m.DiagnosticDefaultSeverity = {
}
---@alias DiagnosticDefaultNeededFileStatus
----| '"Any"'
----| '"Opened"'
----| '"None"'
+---| 'Any'
+---| 'Opened'
+---| 'None'
-- 文件状态
m.FileStatus = {
@@ -88,6 +90,7 @@ m.DiagnosticDefaultNeededFileStatus = {
['newline-call'] = 'Any',
['newfield-call'] = 'Any',
['redundant-parameter'] = 'Opened',
+ ['missing-parameter'] = 'Opened',
['redundant-return'] = 'Opened',
['ambiguity-1'] = 'Any',
['lowercase-global'] = 'Any',
@@ -106,6 +109,7 @@ m.DiagnosticDefaultNeededFileStatus = {
['await-in-sync'] = 'None',
['not-yieldable'] = 'None',
['discard-returns'] = 'Opened',
+ ['need-check-nil'] = 'Opened',
['type-check'] = 'None',
['duplicate-doc-alias'] = 'Any',
diff --git a/script/provider/diagnostic.lua b/script/provider/diagnostic.lua
index b359c21c..15b08d49 100644
--- a/script/provider/diagnostic.lua
+++ b/script/provider/diagnostic.lua
@@ -128,12 +128,17 @@ local function mergeDiags(a, b, c)
merge(b)
merge(c)
+ if #t == 0 then
+ return nil
+ end
+
return t
end
+-- enable `push`, disable `clear`
function m.clear(uri)
await.close('diag:' .. uri)
- if not m.cache[uri] then
+ if m.cache[uri] == nil then
return
end
m.cache[uri] = nil
@@ -144,6 +149,7 @@ function m.clear(uri)
log.info('clearDiagnostics', uri)
end
+-- enable `push` and `send`
function m.clearCache(uri)
m.cache[uri] = false
end
@@ -251,14 +257,7 @@ function m.doDiagnostic(uri, isScopeDiag)
version = version,
diagnostics = full,
})
- if #full > 0 then
- log.debug('publishDiagnostics', uri, #full)
- end
- end
-
- -- always re-sent diagnostics of current file
- if not isScopeDiag then
- m.cache[uri] = nil
+ log.debug('publishDiagnostics', uri, #full)
end
pushResult()
@@ -435,6 +434,7 @@ files.watch(function (ev, uri) ---@async
m.refresh(uri)
elseif ev == 'open' then
if ws.isReady(uri) then
+ m.clearCache(uri)
xpcall(m.doDiagnostic, log.error, uri)
end
elseif ev == 'close' then
diff --git a/script/provider/provider.lua b/script/provider/provider.lua
index b8b101ed..08b6ca93 100644
--- a/script/provider/provider.lua
+++ b/script/provider/provider.lua
@@ -42,8 +42,9 @@ local function updateConfig(uri)
end
local rc = cfgLoader.loadRCConfig(folder.uri, '.luarc.json')
+ or cfgLoader.loadRCConfig(folder.uri, '.luarc.jsonc')
if rc then
- log.info('Load config from luarc.json', folder.uri)
+ log.info('Load config from .luarc.json/.luarc.jsonc', folder.uri)
log.debug(inspect(rc))
end
@@ -91,6 +92,14 @@ filewatch.event(function (ev, path) ---@async
end
end
end
+ if util.stringEndWith(path, '.luarc.jsonc') then
+ for _, scp in ipairs(workspace.folders) do
+ local rcPath = workspace.getAbsolutePath(scp.uri, '.luarc.jsonc')
+ if path == rcPath then
+ updateConfig(scp.uri)
+ end
+ end
+ end
end)
m.register 'initialize' {
@@ -226,7 +235,6 @@ m.register 'workspace/didRenameFiles' {
}
m.register 'textDocument/didOpen' {
- ---@async
function (params)
local doc = params.textDocument
local scheme = furi.split(doc.uri)
@@ -235,7 +243,6 @@ m.register 'textDocument/didOpen' {
end
local uri = files.getRealUri(doc.uri)
log.debug('didOpen', uri)
- workspace.awaitReady(uri)
local text = doc.text
files.setText(uri, text, true, function (file)
file.version = doc.version
@@ -257,13 +264,14 @@ m.register 'textDocument/didClose' {
}
m.register 'textDocument/didChange' {
- ---@async
function (params)
local doc = params.textDocument
+ local scheme = furi.split(doc.uri)
+ if scheme ~= 'file' then
+ return
+ end
local changes = params.contentChanges
local uri = files.getRealUri(doc.uri)
- workspace.awaitReady(uri)
- --log.debug('changes', util.dump(changes))
local text = files.getOriginText(uri) or ''
local rows = files.getCachedRows(uri)
text, rows = tm(text, rows, changes)
@@ -521,7 +529,8 @@ m.register 'textDocument/completion' {
local count, max = workspace.getLoadingProcess(uri)
return {
{
- label = lang.script('HOVER_WS_LOADING', count, max),textEdit = {
+ label = lang.script('HOVER_WS_LOADING', count, max),
+ textEdit = {
range = {
start = params.position,
['end'] = params.position,
diff --git a/script/pub/pub.lua b/script/pub/pub.lua
index e73aea51..47591ee6 100644
--- a/script/pub/pub.lua
+++ b/script/pub/pub.lua
@@ -124,7 +124,7 @@ end
--- 通过 jumpQueue 可以插队
---@param name string
---@param params any
----@param callback function
+---@param callback? function
function m.task(name, params, callback)
local info = {
id = counter(),
diff --git a/script/service/telemetry.lua b/script/service/telemetry.lua
index 50af39b1..2e52def2 100644
--- a/script/service/telemetry.lua
+++ b/script/service/telemetry.lua
@@ -99,7 +99,7 @@ timer.wait(5, function ()
end
local suc, link = pcall(net.connect, 'tcp', 'moe-moe.love', 11577)
if not suc then
- suc, link = pcall(net.connect, 'tcp', '154.23.191.94', 11577)
+ suc, link = pcall(net.connect, 'tcp', '154.23.191.39', 11577)
end
if not suc or not link then
return
diff --git a/script/utility.lua b/script/utility.lua
index 5a52e417..47b0c8d8 100644
--- a/script/utility.lua
+++ b/script/utility.lua
@@ -83,7 +83,7 @@ local m = {}
--- 打印表的结构
---@param tbl table
----@param option table {optional = 'self'}
+---@param option? table
---@return string
function m.dump(tbl, option)
if not option then
@@ -315,8 +315,8 @@ function m.saveFile(path, content)
end
--- 计数器
----@param init integer {optional = 'after'}
----@param step integer {optional = 'after'}
+---@param init? integer
+---@param step? integer
---@return fun():integer
function m.counter(init, step)
if not step then
@@ -346,8 +346,8 @@ function m.sortPairs(t, sorter)
end
--- 深拷贝(不处理元表)
----@param source table
----@param target table {optional = 'self'}
+---@param source table
+---@param target? table
function m.deepCopy(source, target)
local mark = {}
local function copy(a, b)
@@ -566,7 +566,7 @@ end
---遍历文本的每一行
---@param text string
----@param keepNL boolean # 保留换行符
+---@param keepNL? boolean # 保留换行符
---@return fun(text:string):string, integer
function m.eachLine(text, keepNL)
local offset = 1
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 8126f393..75620d19 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -1,10 +1,6 @@
local guide = require 'parser.guide'
local util = require 'utility'
-local localID = require 'vm.local-id'
-local globalMgr = require 'vm.global-manager'
-local signMgr = require 'vm.sign'
local config = require 'config'
-local genericMgr = require 'vm.generic'
local rpath = require 'workspace.require-path'
local files = require 'files'
---@class vm
@@ -13,7 +9,6 @@ local vm = require 'vm.vm'
---@class parser.object
---@field _compiledNodes boolean
---@field _node vm.node
----@field _localBase table
---@field _globalBase table
local searchFieldSwitch = util.switch()
@@ -54,7 +49,7 @@ local searchFieldSwitch = util.switch()
: case 'string'
: call(function (suri, source, key, ref, pushResult)
-- change to `string: stringlib` ?
- local stringlib = globalMgr.getGlobal('type', 'stringlib')
+ local stringlib = vm.getGlobal('type', 'stringlib')
if stringlib then
vm.getClassFields(suri, stringlib, key, ref, pushResult)
end
@@ -64,9 +59,9 @@ local searchFieldSwitch = util.switch()
: call(function (suri, node, key, ref, pushResult)
local fields
if key then
- fields = localID.getSources(node, key)
+ fields = vm.getLocalSources(node, key)
else
- fields = localID.getFields(node)
+ fields = vm.getLocalFields(node)
end
if fields then
for _, src in ipairs(fields) do
@@ -119,7 +114,7 @@ local searchFieldSwitch = util.switch()
if type(key) ~= 'string' then
return
end
- local global = globalMgr.getGlobal('variable', node.name, key)
+ local global = vm.getGlobal('variable', node.name, key)
if global then
for _, set in ipairs(global:getSets(suri)) do
pushResult(set)
@@ -131,7 +126,7 @@ local searchFieldSwitch = util.switch()
end
end
else
- local globals = globalMgr.getFields('variable', node.name)
+ local globals = vm.getGlobalFields('variable', node.name)
for _, global in ipairs(globals) do
for _, set in ipairs(global:getSets(suri)) do
pushResult(set)
@@ -158,7 +153,7 @@ local searchFieldSwitch = util.switch()
if type(key) ~= 'string' then
return
end
- local global = globalMgr.getGlobal('variable', node.name, key)
+ local global = vm.getGlobal('variable', node.name, key)
if global then
for _, set in ipairs(global:getSets(suri)) do
pushResult(set)
@@ -168,7 +163,7 @@ local searchFieldSwitch = util.switch()
end
end
else
- local globals = globalMgr.getFields('variable', node.name)
+ local globals = vm.getGlobalFields('variable', node.name)
for _, global in ipairs(globals) do
for _, set in ipairs(global:getSets(suri)) do
pushResult(set)
@@ -185,7 +180,7 @@ local searchFieldSwitch = util.switch()
end)
-function vm.getClassFields(suri, node, key, ref, pushResult)
+function vm.getClassFields(suri, object, key, ref, pushResult)
local mark = {}
local function searchClass(class, searchedFields)
@@ -201,11 +196,51 @@ function vm.getClassFields(suri, node, key, ref, pushResult)
local hasFounded = {}
for _, field in ipairs(set.fields) do
local fieldKey = guide.getKeyName(field)
- if key == nil
- or fieldKey == key then
- if not searchedFields[fieldKey] then
- pushResult(field)
- hasFounded[fieldKey] = true
+ if fieldKey then
+ -- ---@field x boolean -> class.x
+ if key == nil
+ or fieldKey == key then
+ if not searchedFields[fieldKey] then
+ pushResult(field)
+ hasFounded[fieldKey] = true
+ end
+ end
+ end
+ if not hasFounded[fieldKey] then
+ local keyType = type(key)
+ if keyType == 'table' then
+ -- ---@field [integer] boolean -> class[integer]
+ local fieldNode = vm.compileNode(field.field)
+ if vm.isSubType(suri, key.name, fieldNode) then
+ local nkey = '|' .. key.name
+ if not searchedFields[nkey] then
+ pushResult(field)
+ hasFounded[nkey] = true
+ end
+ end
+ else
+ local typeName
+ if keyType == 'number' then
+ if math.tointeger(key) then
+ typeName = 'integer'
+ else
+ typeName = 'number'
+ end
+ elseif keyType == 'boolean'
+ or keyType == 'string' then
+ typeName = keyType
+ end
+ if typeName then
+ -- ---@field [integer] boolean -> class[1]
+ local fieldNode = vm.compileNode(field.field)
+ if vm.isSubType(suri, typeName, fieldNode) then
+ local nkey = '|' .. typeName
+ if not searchedFields[nkey] then
+ pushResult(field)
+ hasFounded[nkey] = true
+ end
+ end
+ end
end
end
end
@@ -214,19 +249,23 @@ function vm.getClassFields(suri, node, key, ref, pushResult)
for _, src in ipairs(set.bindSources) do
searchFieldSwitch(src.type, suri, src, key, ref, function (field)
local fieldKey = guide.getKeyName(field)
- if not searchedFields[fieldKey]
- and guide.isSet(field) then
- hasFounded[fieldKey] = true
- pushResult(field)
+ if fieldKey then
+ if not searchedFields[fieldKey]
+ and guide.isSet(field) then
+ hasFounded[fieldKey] = true
+ pushResult(field)
+ end
end
end)
if src.value and src.value.type == 'table' then
searchFieldSwitch('table', suri, src.value, key, ref, function (field)
local fieldKey = guide.getKeyName(field)
- if not searchedFields[fieldKey]
- and guide.isSet(field) then
- hasFounded[fieldKey] = true
- pushResult(field)
+ if fieldKey then
+ if not searchedFields[fieldKey]
+ and guide.isSet(field) then
+ hasFounded[fieldKey] = true
+ pushResult(field)
+ end
end
end)
end
@@ -239,7 +278,7 @@ function vm.getClassFields(suri, node, key, ref, pushResult)
end
for _, extend in ipairs(set.extends) do
if extend.type == 'doc.extends.name' then
- local extendType = globalMgr.getGlobal('type', extend[1])
+ local extendType = vm.getGlobal('type', extend[1])
if extendType then
searchClass(extendType, searchedFields)
end
@@ -253,12 +292,12 @@ function vm.getClassFields(suri, node, key, ref, pushResult)
local function searchGlobal(class)
if class.cate == 'type' and class.name == '_G' then
if key == nil then
- local sets = globalMgr.getGlobalSets(suri, 'variable')
+ local sets = vm.getGlobalSets(suri, 'variable')
for _, set in ipairs(sets) do
pushResult(set)
end
else
- local global = globalMgr.getGlobal('variable', key)
+ local global = vm.getGlobal('variable', key)
if global then
for _, set in ipairs(global:getSets(suri)) do
pushResult(set)
@@ -268,8 +307,8 @@ function vm.getClassFields(suri, node, key, ref, pushResult)
end
end
- searchClass(node)
- searchGlobal(node)
+ searchClass(object)
+ searchGlobal(object)
end
---@class parser.object
@@ -283,10 +322,13 @@ local function getObjectSign(source)
end
source._sign = false
if source.type == 'function' then
+ if not source.bindDocs then
+ return false
+ end
for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.generic' then
if not source._sign then
- source._sign = signMgr()
+ source._sign = vm.createSign()
break
end
end
@@ -314,14 +356,18 @@ local function getObjectSign(source)
if not hasGeneric then
return false
end
- source._sign = signMgr()
+ source._sign = vm.createSign()
if source.type == 'doc.type.function' then
for _, arg in ipairs(source.args) do
- local argNode = vm.compileNode(arg.extends)
- if arg.optional then
- argNode:addOptional()
+ if arg.extends then
+ local argNode = vm.compileNode(arg.extends)
+ if arg.optional then
+ argNode:addOptional()
+ end
+ source._sign:addSign(argNode)
+ else
+ source._sign:addSign(vm.createNode())
end
- source._sign:addSign(argNode)
end
end
end
@@ -354,7 +400,7 @@ function vm.getReturnOfFunction(func, index)
if not sign then
return rtn
end
- return genericMgr(rtn, sign)
+ return vm.createGeneric(rtn, sign)
end
end
@@ -455,6 +501,9 @@ local function getReturn(func, index, args)
result:merge(rnode)
end
end
+ if result and returnNode:isOptional() then
+ result:addOptional()
+ end
end
end
end
@@ -462,6 +511,25 @@ local function getReturn(func, index, args)
return result
end
+---@param source parser.object
+---@return boolean
+local function bindAs(source)
+ local root = guide.getRoot(source)
+ local docs = root.docs
+ if not docs then
+ return
+ end
+ for _, doc in ipairs(docs) do
+ if doc.type == 'doc.as' and doc.originalComment.start == source.finish + 2 then
+ if doc.as then
+ vm.setNode(source, vm.compileNode(doc.as), true)
+ end
+ return true
+ end
+ end
+ return false
+end
+
local function bindDocs(source)
local isParam = source.parent.type == 'funcargs'
or source.parent.type == 'in'
@@ -485,7 +553,11 @@ local function bindDocs(source)
end
if doc.type == 'doc.param' then
if isParam and source[1] == doc.param[1] then
- vm.setNode(source, vm.compileNode(doc))
+ local node = vm.compileNode(doc)
+ if doc.optional then
+ node:addOptional()
+ end
+ vm.setNode(source, node)
return true
end
end
@@ -503,12 +575,17 @@ local function bindDocs(source)
vm.setNode(source, vm.compileNode(ast))
return true
end
+ if doc.type == 'doc.overload' then
+ if not isParam then
+ vm.setNode(source, vm.compileNode(doc))
+ end
+ end
end
return false
end
local function compileByLocalID(source)
- local sources = localID.getSources(source)
+ local sources = vm.getLocalSources(source)
if not sources then
return
end
@@ -571,7 +648,7 @@ local function selectNode(source, list, index)
if exp.type == 'call' then
result = getReturn(exp.node, index, exp.args)
if not result then
- vm.setNode(source, globalMgr.getGlobal('type', 'unknown'))
+ vm.setNode(source, vm.declareGlobal('type', 'unknown'))
return vm.getNode(source)
end
else
@@ -597,7 +674,7 @@ local function selectNode(source, list, index)
end
end
if not hasKnownType then
- rtnNode:merge(globalMgr.getGlobal('type', 'unknown'))
+ rtnNode:merge(vm.declareGlobal('type', 'unknown'))
end
vm.setNode(source, rtnNode)
return rtnNode
@@ -664,10 +741,21 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
for n in callNode:eachObject() do
if n.type == 'function' then
+ local sign = getObjectSign(n)
local farg = getFuncArg(n, myIndex)
if farg then
for fn in vm.compileNode(farg):eachObject() do
if isValidCallArgNode(arg, fn) then
+ if fn.type == 'doc.type.function' then
+ if sign then
+ local generic = vm.createGeneric(fn, sign)
+ local args = {}
+ for i = fixIndex + 1, myIndex - 1 do
+ args[#args+1] = call.args[i]
+ end
+ fn = generic:resolve(guide.getUri(call), args)
+ end
+ end
vm.setNode(arg, fn)
end
end
@@ -716,29 +804,19 @@ function vm.compileCallArg(arg, call, index)
if call.node.special == 'pcall'
or call.node.special == 'xpcall' then
local fixIndex = call.node.special == 'pcall' and 1 or 2
- callNode = vm.compileNode(call.args[1])
- compileCallArgNode(arg, call, callNode, fixIndex, index - fixIndex)
+ if call.args and call.args[1] then
+ callNode = vm.compileNode(call.args[1])
+ compileCallArgNode(arg, call, callNode, fixIndex, index - fixIndex)
+ end
end
return vm.getNode(arg)
end
---@param source parser.object
---@return vm.node
-local function compileLocalBase(source)
- if not source._localBase then
- source._localBase = {
- type = 'localbase',
- parent = source,
- }
- end
- local baseNode = vm.getNode(source._localBase)
- if baseNode then
- return baseNode
- end
- baseNode = vm.createNode()
- vm.setNode(source._localBase, baseNode, true)
-
+local function compileLocal(source)
vm.setNode(source, source)
+
local hasMarkDoc
if source.bindDocs then
hasMarkDoc = bindDocs(source)
@@ -788,14 +866,19 @@ local function compileLocalBase(source)
if n.type == 'doc.type.function' then
for index, arg in ipairs(n.args) do
if func.args[index] == source then
- vm.setNode(source, vm.compileNode(arg))
+ local argNode = vm.compileNode(arg)
+ for an in argNode:eachObject() do
+ if an.type ~= 'doc.generic.name' then
+ vm.setNode(source, an)
+ end
+ end
hasDocArg = true
end
end
end
end
if not hasDocArg then
- vm.setNode(source, globalMgr.getGlobal('type', 'any'))
+ vm.setNode(source, vm.declareGlobal('type', 'any'))
end
end
-- for x in ... do
@@ -805,15 +888,10 @@ local function compileLocalBase(source)
-- for x = ... do
if source.parent.type == 'loop' then
- vm.setNode(source, globalMgr.getGlobal('type', 'integer'))
+ vm.compileNode(source.parent)
end
- baseNode:merge(vm.getNode(source))
- vm.removeNode(source)
-
- baseNode:setData('hasDefined', hasMarkDoc or hasMarkParam or hasMarkValue)
-
- return baseNode
+ vm.getNode(source):setData('hasDefined', hasMarkDoc or hasMarkParam or hasMarkValue)
end
local compilerSwitch = util.switch()
@@ -867,41 +945,79 @@ local compilerSwitch = util.switch()
end)
: case 'paren'
: call(function (source)
+ if bindAs(source) then
+ return
+ end
if source.exp then
vm.setNode(source, vm.compileNode(source.exp))
end
end)
: case 'local'
: case 'self'
+ ---@param source parser.object
: call(function (source)
- local baseNode = compileLocalBase(source)
- vm.setNode(source, baseNode, true)
- if not baseNode:getData 'hasDefined' and source.ref then
+ compileLocal(source)
+ local refs = source.ref
+ if not refs then
+ return
+ end
+
+ local hasMark = vm.getNode(source):getData 'hasDefined'
+
+ local runner = vm.createRunner(source)
+ runner:launch(function (src, node)
+ if src.type == 'setlocal' then
+ if src.bindDocs then
+ for _, doc in ipairs(src.bindDocs) do
+ if doc.type == 'doc.type' then
+ vm.setNode(src, vm.compileNode(doc), true)
+ return vm.getNode(src)
+ end
+ end
+ end
+ if src.value and guide.isLiteral(src.value) then
+ if src.value.type == 'table' then
+ vm.setNode(src, vm.createNode(src.value), true)
+ else
+ vm.setNode(src, vm.compileNode(src.value), true)
+ end
+ elseif src.value
+ and src.value.type == 'binary'
+ and src.value.op and src.value.op.type == 'or'
+ and src.value[1] and src.value[1].type == 'getlocal' and src.value[1].node == source then
+ -- x = x or 1
+ vm.setNode(src, vm.compileNode(src.value))
+ else
+ vm.setNode(src, node, true)
+ end
+ return vm.getNode(src)
+ elseif src.type == 'getlocal' then
+ if bindAs(src) then
+ return
+ end
+ vm.setNode(src, node, true)
+ end
+ end)
+
+ if not hasMark then
+ local parentFunc = guide.getParentFunction(source)
for _, ref in ipairs(source.ref) do
- if ref.type == 'setlocal' then
- vm.setNode(source, vm.compileNode(ref))
+ if ref.type == 'setlocal'
+ and guide.getParentFunction(ref) == parentFunc then
+ vm.setNode(source, vm.getNode(ref))
end
end
end
end)
: case 'setlocal'
: call(function (source)
- local baseNode = compileLocalBase(source.node)
- if not baseNode:getData 'hasDefined' and source.value then
- if source.value.type == 'table' then
- vm.setNode(source, source.value)
- else
- vm.setNode(source, vm.compileNode(source.value))
- end
- end
- baseNode:merge(vm.getNode(source))
- vm.setNode(source, baseNode, true)
vm.compileNode(source.node)
end)
: case 'getlocal'
: call(function (source)
- local baseNode = compileLocalBase(source.node)
- vm.setNode(source, baseNode, true)
+ if bindAs(source) then
+ return
+ end
vm.compileNode(source.node)
end)
: case 'setfield'
@@ -924,6 +1040,9 @@ local compilerSwitch = util.switch()
: case 'getmethod'
: case 'getindex'
: call(function (source)
+ if bindAs(source) then
+ return
+ end
compileByLocalID(source)
local key = guide.getKeyName(source)
if key == nil and source.index then
@@ -959,6 +1078,9 @@ local compilerSwitch = util.switch()
end)
: case 'getglobal'
: call(function (source)
+ if bindAs(source) then
+ return
+ end
if source.node[1] ~= '_ENV' then
return
end
@@ -1019,7 +1141,7 @@ local compilerSwitch = util.switch()
end)
end
if hasGeneric then
- vm.setNode(source, genericMgr(rtn, sign))
+ vm.setNode(source, vm.createGeneric(rtn, sign))
else
vm.setNode(source, vm.compileNode(rtn))
end
@@ -1092,29 +1214,44 @@ local compilerSwitch = util.switch()
-- for k, v in pairs(t) do
--> for k, v in iterator, status, initValue do
--> local k, v = iterator(status, initValue)
- source._iterator = {}
- source._iterArgs = {{}, {}}
- -- iterator
- selectNode(source._iterator, source.exps, 1)
- -- status
- selectNode(source._iterArgs[1], source.exps, 2)
- -- initValue
- selectNode(source._iterArgs[2], source.exps, 3)
- end
+ source._iterator = {
+ type = 'dummyfunc',
+ parent = source,
+ }
+ source._iterArgs = {{},{}}
+ end
+ -- iterator
+ selectNode(source._iterator, source.exps, 1)
+ -- status
+ selectNode(source._iterArgs[1], source.exps, 2)
+ -- initValue
+ selectNode(source._iterArgs[2], source.exps, 3)
if source.keys then
for i, loc in ipairs(source.keys) do
local node = getReturn(source._iterator, i, source._iterArgs)
if node then
+ if i == 1 then
+ node:removeOptional()
+ end
vm.setNode(loc, node)
end
end
end
end)
+ : case 'loop'
+ : call(function (source)
+ if source.loc then
+ vm.setNode(source.loc, vm.declareGlobal('type', 'integer'))
+ end
+ end)
: case 'doc.type'
: call(function (source)
for _, typeUnit in ipairs(source.types) do
vm.setNode(source, vm.compileNode(typeUnit))
end
+ if source.optional then
+ vm.getNode(source):addOptional()
+ end
end)
: case 'doc.type.integer'
: case 'doc.type.string'
@@ -1130,7 +1267,13 @@ local compilerSwitch = util.switch()
: call(function (source)
local uri = guide.getUri(source)
vm.setNode(source, source)
- local global = globalMgr.getGlobal('type', source.node[1])
+ if not source.node[1] then
+ return
+ end
+ local global = vm.getGlobal('type', source.node[1])
+ if not global then
+ return
+ end
for _, set in ipairs(global:getSets(uri)) do
if set.type == 'doc.class' then
if set.extends then
@@ -1161,14 +1304,22 @@ local compilerSwitch = util.switch()
if not source.extends then
return
end
- vm.setNode(source, vm.compileNode(source.extends))
+ local fieldNode = vm.compileNode(source.extends)
+ if source.optional then
+ fieldNode:addOptional()
+ end
+ vm.setNode(source, fieldNode)
end)
: case 'doc.type.field'
: call(function (source)
if not source.extends then
return
end
- vm.setNode(source, vm.compileNode(source.extends))
+ local fieldNode = vm.compileNode(source.extends)
+ if source.optional then
+ fieldNode:addOptional()
+ end
+ vm.setNode(source, fieldNode)
end)
: case 'doc.param'
: call(function (source)
@@ -1208,7 +1359,7 @@ local compilerSwitch = util.switch()
end)
: case 'doc.see.name'
: call(function (source)
- local type = globalMgr.getGlobal('type', source[1])
+ local type = vm.getGlobal('type', source[1])
if type then
vm.setNode(source, vm.compileNode(type))
end
@@ -1218,7 +1369,10 @@ local compilerSwitch = util.switch()
if source.extends then
vm.setNode(source, vm.compileNode(source.extends))
else
- vm.setNode(source, globalMgr.getGlobal('type', 'any'))
+ vm.setNode(source, vm.declareGlobal('type', 'any'))
+ end
+ if source.optional then
+ vm.getNode(source):addOptional()
end
end)
: case 'generic'
@@ -1227,10 +1381,16 @@ local compilerSwitch = util.switch()
end)
: case 'unary'
: call(function (source)
+ if bindAs(source) then
+ return
+ end
+ if not source[1] then
+ return
+ end
if source.op.type == 'not' then
local result = vm.test(source[1])
if result == nil then
- vm.setNode(source, globalMgr.getGlobal('type', 'boolean'))
+ vm.setNode(source, vm.declareGlobal('type', 'boolean'))
return
else
vm.setNode(source, {
@@ -1244,13 +1404,13 @@ local compilerSwitch = util.switch()
end
end
if source.op.type == '#' then
- vm.setNode(source, globalMgr.getGlobal('type', 'integer'))
+ vm.setNode(source, vm.declareGlobal('type', 'integer'))
return
end
if source.op.type == '-' then
local v = vm.getNumber(source[1])
if v == nil then
- vm.setNode(source, globalMgr.getGlobal('type', 'number'))
+ vm.setNode(source, vm.declareGlobal('type', 'number'))
return
else
vm.setNode(source, {
@@ -1266,7 +1426,7 @@ local compilerSwitch = util.switch()
if source.op.type == '~' then
local v = vm.getInteger(source[1])
if v == nil then
- vm.setNode(source, globalMgr.getGlobal('type', 'integer'))
+ vm.setNode(source, vm.declareGlobal('type', 'integer'))
return
else
vm.setNode(source, {
@@ -1282,34 +1442,42 @@ local compilerSwitch = util.switch()
end)
: case 'binary'
: call(function (source)
+ if bindAs(source) then
+ return
+ end
+ if not source[1] or not source[2] then
+ return
+ end
if source.op.type == 'and' then
+ local node1 = vm.compileNode(source[1])
+ local node2 = vm.compileNode(source[2])
local r1 = vm.test(source[1])
if r1 == true then
- vm.setNode(source, vm.compileNode(source[2]))
- return
- end
- if r1 == false then
- vm.setNode(source, vm.compileNode(source[1]))
- return
+ vm.setNode(source, node2)
+ elseif r1 == false then
+ vm.setNode(source, node1)
+ else
+ vm.setNode(source, node2)
end
- return
end
if source.op.type == 'or' then
+ local node1 = vm.compileNode(source[1])
+ local node2 = vm.compileNode(source[2])
local r1 = vm.test(source[1])
if r1 == true then
- vm.setNode(source, vm.compileNode(source[1]))
- return
- end
- if r1 == false then
- vm.setNode(source, vm.compileNode(source[2]))
- return
+ vm.setNode(source, node1)
+ elseif r1 == false then
+ vm.setNode(source, node2)
+ else
+ vm.getNode(source):merge(node1)
+ vm.getNode(source):setTruthy()
+ vm.getNode(source):merge(node2)
end
- return
end
if source.op.type == '==' then
local result = vm.equal(source[1], source[2])
if result == nil then
- vm.setNode(source, globalMgr.getGlobal('type', 'boolean'))
+ vm.setNode(source, vm.declareGlobal('type', 'boolean'))
return
else
vm.setNode(source, {
@@ -1325,7 +1493,7 @@ local compilerSwitch = util.switch()
if source.op.type == '~=' then
local result = vm.equal(source[1], source[2])
if result == nil then
- vm.setNode(source, globalMgr.getGlobal('type', 'boolean'))
+ vm.setNode(source, vm.declareGlobal('type', 'boolean'))
return
else
vm.setNode(source, {
@@ -1351,7 +1519,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'integer'))
+ vm.setNode(source, vm.declareGlobal('type', 'integer'))
return
end
end
@@ -1368,7 +1536,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'integer'))
+ vm.setNode(source, vm.declareGlobal('type', 'integer'))
return
end
end
@@ -1385,7 +1553,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'integer'))
+ vm.setNode(source, vm.declareGlobal('type', 'integer'))
return
end
end
@@ -1402,7 +1570,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'integer'))
+ vm.setNode(source, vm.declareGlobal('type', 'integer'))
return
end
end
@@ -1419,7 +1587,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'integer'))
+ vm.setNode(source, vm.declareGlobal('type', 'integer'))
return
end
end
@@ -1437,7 +1605,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'number'))
+ vm.setNode(source, vm.declareGlobal('type', 'number'))
return
end
end
@@ -1455,7 +1623,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'number'))
+ vm.setNode(source, vm.declareGlobal('type', 'number'))
return
end
end
@@ -1473,7 +1641,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'number'))
+ vm.setNode(source, vm.declareGlobal('type', 'number'))
return
end
end
@@ -1490,14 +1658,14 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'number'))
+ vm.setNode(source, vm.declareGlobal('type', 'number'))
return
end
end
if source.op.type == '%' then
local a = vm.getNumber(source[1])
local b = vm.getNumber(source[2])
- if a and b then
+ if a and b and b ~= 0 then
local result = a % b
vm.setNode(source, {
type = math.type(result) == 'integer' and 'integer' or 'number',
@@ -1508,7 +1676,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'number'))
+ vm.setNode(source, vm.declareGlobal('type', 'number'))
return
end
end
@@ -1525,7 +1693,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'number'))
+ vm.setNode(source, vm.declareGlobal('type', 'number'))
return
end
end
@@ -1543,7 +1711,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'number'))
+ vm.setNode(source, vm.declareGlobal('type', 'number'))
return
end
end
@@ -1580,7 +1748,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'string'))
+ vm.setNode(source, vm.declareGlobal('type', 'string'))
return
end
end
@@ -1614,17 +1782,20 @@ local function compileByGlobal(source)
vm.setNode(source, globalNode, true)
return
end
+ ---@type vm.node
globalNode = vm.createNode(global)
vm.setNode(root._globalBase[name], globalNode, true)
+ vm.setNode(source, globalNode, true)
- local sets = global.links[uri].sets or {}
- local gets = global.links[uri].gets or {}
- for _, set in ipairs(sets) do
- vm.setNode(set, globalNode, true)
- end
- for _, get in ipairs(gets) do
- vm.setNode(get, globalNode, true)
- end
+ -- TODO:don't mix
+ --local sets = global.links[uri].sets or {}
+ --local gets = global.links[uri].gets or {}
+ --for _, set in ipairs(sets) do
+ -- vm.setNode(set, globalNode, true)
+ --end
+ --for _, get in ipairs(gets) do
+ -- vm.setNode(get, globalNode, true)
+ --end
if global.cate == 'variable' then
local hasMarkDoc
@@ -1672,7 +1843,11 @@ end
---@return vm.node
function vm.compileNode(source)
if not source then
- error('Can not compile nil node')
+ if TEST then
+ error('Can not compile nil source')
+ else
+ log.error('Can not compile nil source')
+ end
end
if source.type == 'global' then
diff --git a/script/vm/def.lua b/script/vm/def.lua
index b66e8fda..83e92686 100644
--- a/script/vm/def.lua
+++ b/script/vm/def.lua
@@ -2,8 +2,6 @@
local vm = require 'vm.vm'
local util = require 'utility'
local guide = require 'parser.guide'
-local localID = require 'vm.local-id'
-local globalMgr = require 'vm.global-manager'
local simpleSwitch
@@ -79,6 +77,13 @@ simpleSwitch = util.switch()
pushResult(source.node)
end
end)
+ : case 'doc.cast.name'
+ : call(function (source, pushResult)
+ local loc = guide.getLocal(source, source[1], source.start)
+ if loc then
+ pushResult(loc)
+ end
+ end)
local searchFieldSwitch = util.switch()
: case 'table'
@@ -97,7 +102,7 @@ local searchFieldSwitch = util.switch()
---@param key string
: call(function (suri, obj, key, pushResult)
if obj.cate == 'variable' then
- local newGlobal = globalMgr.getGlobal('variable', obj.name, key)
+ local newGlobal = vm.getGlobal('variable', obj.name, key)
if newGlobal then
for _, set in ipairs(newGlobal:getSets(suri)) do
pushResult(set)
@@ -110,7 +115,7 @@ local searchFieldSwitch = util.switch()
end)
: case 'local'
: call(function (suri, obj, key, pushResult)
- local sources = localID.getSources(obj, key)
+ local sources = vm.getLocalSources(obj, key)
if sources then
for _, src in ipairs(sources) do
if guide.isSet(src) then
@@ -189,7 +194,7 @@ end
---@param source parser.object
---@param pushResult fun(src: parser.object)
local function searchByLocalID(source, pushResult)
- local idSources = localID.getSources(source)
+ local idSources = vm.getLocalSources(source)
if not idSources then
return
end
diff --git a/script/vm/doc.lua b/script/vm/doc.lua
index 5a92a103..e2b383b6 100644
--- a/script/vm/doc.lua
+++ b/script/vm/doc.lua
@@ -3,7 +3,6 @@ local guide = require 'parser.guide'
---@class vm
local vm = require 'vm.vm'
local config = require 'config'
-local globalMgr = require 'vm.global-manager'
---获取class与alias
---@param suri uri
@@ -11,13 +10,13 @@ local globalMgr = require 'vm.global-manager'
---@return parser.object[]
function vm.getDocSets(suri, name)
if name then
- local global = globalMgr.getGlobal('type', name)
+ local global = vm.getGlobal('type', name)
if not global then
return {}
end
return global:getSets(suri)
else
- return globalMgr.getGlobalSets(suri, 'type')
+ return vm.getGlobalSets(suri, 'type')
end
end
@@ -27,6 +26,9 @@ function vm.isMetaFile(uri)
return false
end
local cache = files.getCache(uri)
+ if not cache then
+ return false
+ end
if cache.isMeta ~= nil then
return cache.isMeta
end
@@ -332,6 +334,9 @@ function vm.isDiagDisabledAt(uri, position, name)
return false
end
local cache = files.getCache(uri)
+ if not cache then
+ return false
+ end
if not cache.diagnosticRanges then
cache.diagnosticRanges = {}
for _, doc in ipairs(status.ast.docs) do
diff --git a/script/vm/field.lua b/script/vm/field.lua
index ba7cd4c1..5de838be 100644
--- a/script/vm/field.lua
+++ b/script/vm/field.lua
@@ -15,6 +15,15 @@ local searchByNodeSwitch = util.switch()
pushResult(source)
end)
+local function searchByLocalID(source, pushResult)
+ local fields = vm.getLocalFields(source)
+ if fields then
+ for _, field in ipairs(fields) do
+ pushResult(field)
+ end
+ end
+end
+
local function searchByNode(source, pushResult)
local uri = guide.getUri(source)
vm.compileByParentNode(source, nil, true, function (field)
@@ -35,6 +44,7 @@ function vm.getFields(source)
end
end
+ searchByLocalID(source, pushResult)
searchByNode(source, pushResult)
return results
diff --git a/script/vm/generic.lua b/script/vm/generic.lua
index b3981ff8..6462028e 100644
--- a/script/vm/generic.lua
+++ b/script/vm/generic.lua
@@ -1,3 +1,4 @@
+---@class vm
local vm = require 'vm.vm'
---@class parser.object
@@ -114,7 +115,7 @@ end
---@param uri uri
---@param args parser.object
----@return parser.object
+---@return vm.node
function mt:resolve(uri, args)
local resolved = self.sign:resolve(uri, args)
local protoNode = vm.compileNode(self.proto)
@@ -129,7 +130,7 @@ end
---@param proto vm.object
---@param sign vm.sign
---@return vm.generic
-return function (proto, sign)
+function vm.createGeneric(proto, sign)
local generic = setmetatable({
sign = sign,
proto = proto,
diff --git a/script/vm/global-manager.lua b/script/vm/global-manager.lua
deleted file mode 100644
index f25bb5a0..00000000
--- a/script/vm/global-manager.lua
+++ /dev/null
@@ -1,364 +0,0 @@
-local util = require 'utility'
-local guide = require 'parser.guide'
-local globalBuilder = require 'vm.global'
-local signMgr = require 'vm.sign'
-local genericMgr = require 'vm.generic'
----@class vm
-local vm = require 'vm.vm'
-
----@class parser.object
----@field _globalNode vm.global
-
----@class vm.global-manager
-local m = {}
----@type table<string, vm.global>
-m.globals = {}
----@type table<uri, table<string, boolean>>
-m.globalSubs = util.multiTable(2)
-
-local compilerGlobalSwitch = util.switch()
- : case 'local'
- : call(function (source)
- if source.special ~= '_G' then
- return
- end
- if source.ref then
- for _, ref in ipairs(source.ref) do
- m.compileObject(ref)
- end
- end
- end)
- : case 'getlocal'
- : call(function (source)
- if source.special ~= '_G' then
- return
- end
- if not source.next then
- return
- end
- m.compileObject(source.next)
- end)
- : case 'setglobal'
- : call(function (source)
- local uri = guide.getUri(source)
- local name = guide.getKeyName(source)
- local global = m.declareGlobal('variable', name, uri)
- global:addSet(uri, source)
- source._globalNode = global
- end)
- : case 'getglobal'
- : call(function (source)
- local uri = guide.getUri(source)
- local name = guide.getKeyName(source)
- local global = m.declareGlobal('variable', name, uri)
- global:addGet(uri, source)
- source._globalNode = global
-
- local nxt = source.next
- if nxt then
- m.compileObject(nxt)
- end
- end)
- : case 'setfield'
- : case 'setmethod'
- : case 'setindex'
- ---@param source parser.object
- : call(function (source)
- local name
- local keyName = guide.getKeyName(source)
- if not keyName then
- return
- end
- if source.node._globalNode then
- local parentName = source.node._globalNode:getName()
- if parentName == '_G' then
- name = keyName
- else
- name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName)
- end
- elseif source.node.special == '_G' then
- name = keyName
- end
- if not name then
- return
- end
- local uri = guide.getUri(source)
- local global = m.declareGlobal('variable', name, uri)
- global:addSet(uri, source)
- source._globalNode = global
- end)
- : case 'getfield'
- : case 'getmethod'
- : case 'getindex'
- ---@param source parser.object
- : call(function (source)
- local name
- local keyName = guide.getKeyName(source)
- if not keyName then
- return
- end
- if source.node._globalNode then
- local parentName = source.node._globalNode:getName()
- if parentName == '_G' then
- name = keyName
- else
- name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName)
- end
- elseif source.node.special == '_G' then
- name = keyName
- end
- local uri = guide.getUri(source)
- local global = m.declareGlobal('variable', name, uri)
- global:addGet(uri, source)
- source._globalNode = global
-
- local nxt = source.next
- if nxt then
- m.compileObject(nxt)
- end
- end)
- : case 'call'
- : call(function (source)
- if source.node.special == 'rawset'
- or source.node.special == 'rawget' then
- if not source.args then
- return
- end
- local g = source.args[1]
- local key = source.args[2]
- if g and key and g.special == '_G' then
- local name = guide.getKeyName(key)
- if name then
- local uri = guide.getUri(source)
- local global = m.declareGlobal('variable', name, uri)
- if source.node.special == 'rawset' then
- global:addSet(uri, source)
- source.value = source.args[3]
- else
- global:addGet(uri, source)
- end
- source._globalNode = global
-
- local nxt = source.next
- if nxt then
- m.compileObject(nxt)
- end
- end
- end
- end
- end)
- : case 'doc.class'
- ---@param source parser.object
- : call(function (source)
- local uri = guide.getUri(source)
- local name = guide.getKeyName(source)
- local class = m.declareGlobal('type', name, uri)
- class:addSet(uri, source)
- source._globalNode = class
-
- if source.signs then
- source._sign = signMgr()
- for _, sign in ipairs(source.signs) do
- source._sign:addSign(vm.compileNode(sign))
- end
- if source.extends then
- for _, ext in ipairs(source.extends) do
- if ext.type == 'doc.type.table' then
- ext._generic = genericMgr(ext, source._sign)
- end
- end
- end
- end
- end)
- : case 'doc.alias'
- : call(function (source)
- local uri = guide.getUri(source)
- local name = guide.getKeyName(source)
- local alias = m.declareGlobal('type', name, uri)
- alias:addSet(uri, source)
- source._globalNode = alias
-
- if source.signs then
- source._sign = signMgr()
- for _, sign in ipairs(source.signs) do
- source._sign:addSign(vm.compileNode(sign))
- end
- source.extends._generic = genericMgr(source.extends, source._sign)
- end
- end)
- : case 'doc.type.name'
- : call(function (source)
- local uri = guide.getUri(source)
- local name = source[1]
- local type = m.declareGlobal('type', name, uri)
- type:addGet(uri, source)
- source._globalNode = type
- end)
- : case 'doc.extends.name'
- : call(function (source)
- local uri = guide.getUri(source)
- local name = source[1]
- local class = m.declareGlobal('type', name, uri)
- class:addGet(uri, source)
- source._globalNode = class
- end)
-
-
----@alias vm.global.cate '"variable"' | '"type"'
-
----@param cate vm.global.cate
----@param name string
----@param uri uri
----@return vm.global
-function m.declareGlobal(cate, name, uri)
- local key = cate .. '|' .. name
- m.globalSubs[uri][key] = true
- if not m.globals[key] then
- m.globals[key] = globalBuilder(name, cate)
- end
- return m.globals[key]
-end
-
----@param cate vm.global.cate
----@param name string
----@param field? string
----@return vm.global?
-function m.getGlobal(cate, name, field)
- local key = cate .. '|' .. name
- if field then
- key = key .. vm.ID_SPLITE .. field
- end
- return m.globals[key]
-end
-
----@param cate vm.global.cate
----@param name string
----@return vm.global[]
-function m.getFields(cate, name)
- local globals = {}
- local key = cate .. '|' .. name
-
- -- TODO: optimize
- local clock = os.clock()
- for gid, global in pairs(m.globals) do
- if gid ~= key
- and util.stringStartWith(gid, key)
- and gid:sub(#key + 1, #key + 1) == vm.ID_SPLITE
- and not gid:find(vm.ID_SPLITE, #key + 2) then
- globals[#globals+1] = global
- end
- end
- local cost = os.clock() - clock
- if cost > 0.1 then
- log.warn('global-manager getFields cost %.3f', cost)
- end
-
- return globals
-end
-
----@param cate vm.global.cate
----@return vm.global[]
-function m.getGlobals(cate)
- local globals = {}
-
- -- TODO: optimize
- local clock = os.clock()
- for gid, global in pairs(m.globals) do
- if util.stringStartWith(gid, cate)
- and not gid:find(vm.ID_SPLITE) then
- globals[#globals+1] = global
- end
- end
- local cost = os.clock() - clock
- if cost > 0.1 then
- log.warn('global-manager getGlobals cost %.3f', cost)
- end
-
- return globals
-end
-
----@param suri uri
----@param cate vm.global.cate
----@return parser.object[]
-function m.getGlobalSets(suri, cate)
- local globals = m.getGlobals(cate)
- local result = {}
- for _, global in ipairs(globals) do
- local sets = global:getSets(suri)
- for _, set in ipairs(sets) do
- result[#result+1] = set
- end
- end
- return result
-end
-
----@param suri uri
----@param cate vm.global.cate
----@param name string
----@return boolean
-function m.hasGlobalSets(suri, cate, name)
- local global = m.getGlobal(cate, name)
- if not global then
- return false
- end
- local sets = global:getSets(suri)
- if #sets == 0 then
- return false
- end
- return true
-end
-
----@param source parser.object
-function m.compileObject(source)
- if source._globalNode ~= nil then
- return
- end
- source._globalNode = false
- compilerGlobalSwitch(source.type, source)
-end
-
----@param source parser.object
-function m.compileAst(source)
- local env = guide.getENV(source)
- m.compileObject(env)
- guide.eachSpecialOf(source, 'rawset', function (src)
- m.compileObject(src.parent)
- end)
- guide.eachSpecialOf(source, 'rawget', function (src)
- m.compileObject(src.parent)
- end)
- guide.eachSourceTypes(source.docs, {
- 'doc.class',
- 'doc.alias',
- 'doc.type.name',
- 'doc.extends.name',
- }, function (src)
- m.compileObject(src)
- end)
-end
-
----@return vm.global
-function m.getNode(source)
- if source.type == 'field'
- or source.type == 'method' then
- source = source.parent
- end
- return source._globalNode
-end
-
----@param uri uri
-function m.dropUri(uri)
- local globalSub = m.globalSubs[uri]
- m.globalSubs[uri] = nil
- for key in pairs(globalSub) do
- local global = m.globals[key]
- if global then
- global:dropUri(uri)
- if not global:isAlive() then
- m.globals[key] = nil
- end
- end
- end
-end
-
-return m
diff --git a/script/vm/global.lua b/script/vm/global.lua
index 1c46c9a3..a54ab552 100644
--- a/script/vm/global.lua
+++ b/script/vm/global.lua
@@ -1,5 +1,9 @@
-local util = require 'utility'
-local scope= require 'workspace.scope'
+local util = require 'utility'
+local scope = require 'workspace.scope'
+local guide = require 'parser.guide'
+local files = require 'files'
+---@class vm
+local vm = require 'vm.vm'
---@class vm.global.link
---@field gets parser.object[]
@@ -15,8 +19,6 @@ mt.__index = mt
mt.type = 'global'
mt.name = ''
-local ID_SPLITE = '\x1F'
-
---@param uri uri
---@param source parser.object
function mt:addSet(uri, source)
@@ -106,7 +108,7 @@ end
---@return string
function mt:getKeyName()
- return self.name:match('[^' .. ID_SPLITE .. ']+$')
+ return self.name:match('[^' .. vm.ID_SPLITE .. ']+$')
end
---@return boolean
@@ -116,10 +118,427 @@ end
---@param cate vm.global.cate
---@return vm.global
-return function (name, cate)
+local function createGlobal(name, cate)
return setmetatable({
name = name,
cate = cate,
links = util.multiTable(2),
}, mt)
end
+
+---@class parser.object
+---@field _globalNode vm.global
+
+---@type table<string, vm.global>
+local allGlobals = {}
+---@type table<uri, table<string, boolean>>
+local globalSubs = util.multiTable(2)
+
+local compileObject
+local compilerGlobalSwitch = util.switch()
+ : case 'local'
+ : call(function (source)
+ if source.special ~= '_G' then
+ return
+ end
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ compileObject(ref)
+ end
+ end
+ end)
+ : case 'getlocal'
+ : call(function (source)
+ if source.special ~= '_G' then
+ return
+ end
+ if not source.next then
+ return
+ end
+ compileObject(source.next)
+ end)
+ : case 'setglobal'
+ : call(function (source)
+ local uri = guide.getUri(source)
+ local name = guide.getKeyName(source)
+ local global = vm.declareGlobal('variable', name, uri)
+ global:addSet(uri, source)
+ source._globalNode = global
+ end)
+ : case 'getglobal'
+ : call(function (source)
+ local uri = guide.getUri(source)
+ local name = guide.getKeyName(source)
+ local global = vm.declareGlobal('variable', name, uri)
+ global:addGet(uri, source)
+ source._globalNode = global
+
+ local nxt = source.next
+ if nxt then
+ compileObject(nxt)
+ end
+ end)
+ : case 'setfield'
+ : case 'setmethod'
+ : case 'setindex'
+ ---@param source parser.object
+ : call(function (source)
+ local name
+ local keyName = guide.getKeyName(source)
+ if not keyName then
+ return
+ end
+ if source.node._globalNode then
+ local parentName = source.node._globalNode:getName()
+ if parentName == '_G' then
+ name = keyName
+ else
+ name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName)
+ end
+ elseif source.node.special == '_G' then
+ name = keyName
+ end
+ if not name then
+ return
+ end
+ local uri = guide.getUri(source)
+ local global = vm.declareGlobal('variable', name, uri)
+ global:addSet(uri, source)
+ source._globalNode = global
+ end)
+ : case 'getfield'
+ : case 'getmethod'
+ : case 'getindex'
+ ---@param source parser.object
+ : call(function (source)
+ local name
+ local keyName = guide.getKeyName(source)
+ if not keyName then
+ return
+ end
+ if source.node._globalNode then
+ local parentName = source.node._globalNode:getName()
+ if parentName == '_G' then
+ name = keyName
+ else
+ name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName)
+ end
+ elseif source.node.special == '_G' then
+ name = keyName
+ end
+ local uri = guide.getUri(source)
+ local global = vm.declareGlobal('variable', name, uri)
+ global:addGet(uri, source)
+ source._globalNode = global
+
+ local nxt = source.next
+ if nxt then
+ compileObject(nxt)
+ end
+ end)
+ : case 'call'
+ : call(function (source)
+ if source.node.special == 'rawset'
+ or source.node.special == 'rawget' then
+ if not source.args then
+ return
+ end
+ local g = source.args[1]
+ local key = source.args[2]
+ if g and key and g.special == '_G' then
+ local name = guide.getKeyName(key)
+ if name then
+ local uri = guide.getUri(source)
+ local global = vm.declareGlobal('variable', name, uri)
+ if source.node.special == 'rawset' then
+ global:addSet(uri, source)
+ source.value = source.args[3]
+ else
+ global:addGet(uri, source)
+ end
+ source._globalNode = global
+
+ local nxt = source.next
+ if nxt then
+ compileObject(nxt)
+ end
+ end
+ end
+ end
+ end)
+ : case 'doc.class'
+ ---@param source parser.object
+ : call(function (source)
+ local uri = guide.getUri(source)
+ local name = guide.getKeyName(source)
+ local class = vm.declareGlobal('type', name, uri)
+ class:addSet(uri, source)
+ source._globalNode = class
+
+ if source.signs then
+ source._sign = vm.createSign()
+ for _, sign in ipairs(source.signs) do
+ source._sign:addSign(vm.compileNode(sign))
+ end
+ if source.extends then
+ for _, ext in ipairs(source.extends) do
+ if ext.type == 'doc.type.table' then
+ ext._generic = vm.createGeneric(ext, source._sign)
+ end
+ end
+ end
+ end
+ end)
+ : case 'doc.alias'
+ : call(function (source)
+ local uri = guide.getUri(source)
+ local name = guide.getKeyName(source)
+ local alias = vm.declareGlobal('type', name, uri)
+ alias:addSet(uri, source)
+ source._globalNode = alias
+
+ if source.signs then
+ source._sign = vm.createSign()
+ for _, sign in ipairs(source.signs) do
+ source._sign:addSign(vm.compileNode(sign))
+ end
+ source.extends._generic = vm.createGeneric(source.extends, source._sign)
+ end
+ end)
+ : case 'doc.type.name'
+ : call(function (source)
+ local uri = guide.getUri(source)
+ local name = source[1]
+ local type = vm.declareGlobal('type', name, uri)
+ type:addGet(uri, source)
+ source._globalNode = type
+ end)
+ : case 'doc.extends.name'
+ : call(function (source)
+ local uri = guide.getUri(source)
+ local name = source[1]
+ local class = vm.declareGlobal('type', name, uri)
+ class:addGet(uri, source)
+ source._globalNode = class
+ end)
+
+
+---@alias vm.global.cate '"variable"' | '"type"'
+
+---@param cate vm.global.cate
+---@param name string
+---@param uri? uri
+---@return vm.global
+function vm.declareGlobal(cate, name, uri)
+ local key = cate .. '|' .. name
+ if uri then
+ globalSubs[uri][key] = true
+ end
+ if not allGlobals[key] then
+ allGlobals[key] = createGlobal(name, cate)
+ end
+ return allGlobals[key]
+end
+
+---@param cate vm.global.cate
+---@param name string
+---@param field? string
+---@return vm.global?
+function vm.getGlobal(cate, name, field)
+ local key = cate .. '|' .. name
+ if field then
+ key = key .. vm.ID_SPLITE .. field
+ end
+ return allGlobals[key]
+end
+
+---@param cate vm.global.cate
+---@param name string
+---@return vm.global[]
+function vm.getGlobalFields(cate, name)
+ local globals = {}
+ local key = cate .. '|' .. name
+
+ local clock = os.clock()
+ for gid, global in pairs(allGlobals) do
+ if gid ~= key
+ and util.stringStartWith(gid, key)
+ and gid:sub(#key + 1, #key + 1) == vm.ID_SPLITE
+ and not gid:find(vm.ID_SPLITE, #key + 2) then
+ globals[#globals+1] = global
+ end
+ end
+ local cost = os.clock() - clock
+ if cost > 0.1 then
+ log.warn('global-manager getFields cost %.3f', cost)
+ end
+
+ return globals
+end
+
+---@param cate vm.global.cate
+---@return vm.global[]
+function vm.getGlobals(cate)
+ local globals = {}
+
+ local clock = os.clock()
+ for gid, global in pairs(allGlobals) do
+ if util.stringStartWith(gid, cate)
+ and not gid:find(vm.ID_SPLITE) then
+ globals[#globals+1] = global
+ end
+ end
+ local cost = os.clock() - clock
+ if cost > 0.1 then
+ log.warn('global-manager getGlobals cost %.3f', cost)
+ end
+
+ return globals
+end
+
+---@param suri uri
+---@param cate vm.global.cate
+---@return parser.object[]
+function vm.getGlobalSets(suri, cate)
+ local globals = vm.getGlobals(cate)
+ local result = {}
+ for _, global in ipairs(globals) do
+ local sets = global:getSets(suri)
+ for _, set in ipairs(sets) do
+ result[#result+1] = set
+ end
+ end
+ return result
+end
+
+---@param suri uri
+---@param cate vm.global.cate
+---@param name string
+---@return boolean
+function vm.hasGlobalSets(suri, cate, name)
+ local global = vm.getGlobal(cate, name)
+ if not global then
+ return false
+ end
+ local sets = global:getSets(suri)
+ if #sets == 0 then
+ return false
+ end
+ return true
+end
+
+---@param source parser.object
+function compileObject(source)
+ if source._globalNode ~= nil then
+ return
+ end
+ source._globalNode = false
+ compilerGlobalSwitch(source.type, source)
+end
+
+---@param source parser.object
+local function compileSelf(source)
+ if source.parent.type ~= 'funcargs' then
+ return
+ end
+ ---@type parser.object
+ local node = source.parent.parent and source.parent.parent.parent and source.parent.parent.parent.node
+ if not node then
+ return
+ end
+ local fields = vm.getLocalFields(source)
+ if not fields then
+ return
+ end
+ local nodeLocalID = vm.getLocalID(node)
+ local globalNode = node._globalNode
+ if not nodeLocalID and not globalNode then
+ return
+ end
+ for _, field in ipairs(fields) do
+ if field.type == 'setfield' then
+ local key = guide.getKeyName(field)
+ if key then
+ if nodeLocalID then
+ local myID = nodeLocalID .. vm.ID_SPLITE .. key
+ vm.insertLocalID(myID, field)
+ end
+ if globalNode then
+ local myID = globalNode:getName() .. vm.ID_SPLITE .. key
+ local myGlobal = vm.declareGlobal('variable', myID, guide.getUri(node))
+ myGlobal:addSet(guide.getUri(node), field)
+ end
+ end
+ end
+ end
+end
+
+---@param source parser.object
+local function compileAst(source)
+ local env = guide.getENV(source)
+ if not env then
+ return
+ end
+ compileObject(env)
+ guide.eachSpecialOf(source, 'rawset', function (src)
+ compileObject(src.parent)
+ end)
+ guide.eachSpecialOf(source, 'rawget', function (src)
+ compileObject(src.parent)
+ end)
+ guide.eachSourceTypes(source.docs, {
+ 'doc.class',
+ 'doc.alias',
+ 'doc.type.name',
+ 'doc.extends.name',
+ }, function (src)
+ compileObject(src)
+ end)
+
+ --[[
+ local mt
+ function mt:xxx()
+ self.a = 1
+ end
+
+ mt.a --> find this definition
+ ]]
+ guide.eachSourceType(source, 'self', function (src)
+ compileSelf(src)
+ end)
+end
+
+---@param uri uri
+local function dropUri(uri)
+ local globalSub = globalSubs[uri]
+ globalSubs[uri] = nil
+ for key in pairs(globalSub) do
+ local global = allGlobals[key]
+ if global then
+ global:dropUri(uri)
+ if not global:isAlive() then
+ allGlobals[key] = nil
+ end
+ end
+ end
+end
+
+for uri in files.eachFile() do
+ local state = files.getState(uri)
+ if state then
+ compileAst(state.ast)
+ end
+end
+
+files.watch(function (ev, uri)
+ if ev == 'update' then
+ dropUri(uri)
+ local state = files.getState(uri)
+ if state then
+ compileAst(state.ast)
+ end
+ end
+ if ev == 'remove' then
+ dropUri(uri)
+ end
+end)
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index 2a64ed52..fabc9828 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -1,11 +1,9 @@
local util = require 'utility'
local config = require 'config'
local guide = require 'parser.guide'
+---@class vm
local vm = require 'vm.vm'
----@class vm.infer-manager
-local m = {}
-
---@class vm.infer
---@field views table<string, boolean>
---@field cachedView? string
@@ -21,7 +19,7 @@ mt._hasDocFunction = false
mt._isParam = false
mt._isLocal = false
-m.NULL = setmetatable({}, mt)
+vm.NULL = setmetatable({}, mt)
local inferSorted = {
['boolean'] = - 100,
@@ -52,7 +50,7 @@ local viewNodeSwitch = util.switch()
: call(function (source, infer)
if source.type == 'table' then
if #source == 1 and source[1].type == 'varargs' then
- local node = m.getInfer(source[1]):view()
+ local node = vm.getInfer(source[1]):view()
return ('%s[]'):format(node)
end
end
@@ -90,7 +88,7 @@ local viewNodeSwitch = util.switch()
if source.signs then
local buf = {}
for i, sign in ipairs(source.signs) do
- buf[i] = m.getInfer(sign):view()
+ buf[i] = vm.getInfer(sign):view()
end
return ('%s<%s>'):format(source[1], table.concat(buf, ', '))
else
@@ -99,7 +97,7 @@ local viewNodeSwitch = util.switch()
end)
: case 'generic'
: call(function (source, infer)
- return m.getInfer(source.proto):view()
+ return vm.getInfer(source.proto):view()
end)
: case 'doc.generic.name'
: call(function (source, infer)
@@ -108,7 +106,7 @@ local viewNodeSwitch = util.switch()
: case 'doc.type.array'
: call(function (source, infer)
infer._hasClass = true
- local view = m.getInfer(source.node):view()
+ local view = vm.getInfer(source.node):view()
if source.node.type == 'doc.type' then
view = '(' .. view .. ')'
end
@@ -119,7 +117,7 @@ local viewNodeSwitch = util.switch()
infer._hasClass = true
local buf = {}
for i, sign in ipairs(source.signs) do
- buf[i] = m.getInfer(sign):view()
+ buf[i] = vm.getInfer(sign):view()
end
return ('%s<%s>'):format(source.node[1], table.concat(buf, ', '))
end)
@@ -144,20 +142,23 @@ local viewNodeSwitch = util.switch()
local argView = ''
local regView = ''
for i, arg in ipairs(source.args) do
+ local argNode = vm.compileNode(arg)
+ local isOptional = argNode:isOptional()
+ if isOptional then
+ argNode = argNode:copy()
+ argNode:removeOptional()
+ end
args[i] = string.format('%s%s: %s'
, arg.name[1]
- , arg.optional and '?' or ''
- , m.getInfer(arg):view()
+ , isOptional and '?' or ''
+ , vm.getInfer(argNode):view()
)
end
if #args > 0 then
argView = table.concat(args, ', ')
end
for i, ret in ipairs(source.returns) do
- rets[i] = string.format('%s%s'
- , m.getInfer(ret):view()
- , ret.optional and '?' or ''
- )
+ rets[i] = vm.getInfer(ret):view()
end
if #rets > 0 then
regView = ':' .. table.concat(rets, ', ')
@@ -165,16 +166,21 @@ local viewNodeSwitch = util.switch()
return ('fun(%s)%s'):format(argView, regView)
end)
----@param source parser.object
+---@param source parser.object | vm.node
---@return vm.infer
-function m.getInfer(source)
- local node = vm.compileNode(source)
+function vm.getInfer(source)
+ local node
+ if source.type == 'vm.node' then
+ node = source
+ else
+ node = vm.compileNode(source)
+ end
if node.lastInfer then
return node.lastInfer
end
local infer = setmetatable({
node = node,
- uri = guide.getUri(source),
+ uri = source.type ~= 'vm.node' and guide.getUri(source),
}, mt)
node.lastInfer = infer
@@ -199,24 +205,24 @@ function mt:_trim()
if self._hasTable and not self._hasClass then
self.views['table'] = true
end
- if self._hasClass then
- self:_eraseAlias()
- end
end
-function mt:_eraseAlias()
- local expandAlias = config.get(self.uri, 'Lua.hover.expandAlias')
+---@param uri uri
+---@return table<string, true>
+function mt:_eraseAlias(uri)
+ local drop = {}
+ local expandAlias = config.get(uri, 'Lua.hover.expandAlias')
for n in self.node:eachObject() do
if n.type == 'global' and n.cate == 'type' then
- for _, set in ipairs(n:getSets(self.uri)) do
+ for _, set in ipairs(n:getSets(uri)) do
if set.type == 'doc.alias' then
if expandAlias then
- self.views[n.name] = nil
+ drop[n.name] = true
else
for _, ext in ipairs(set.extends.types) do
local view = viewNodeSwitch(ext.type, ext, {})
if view and view ~= n.name then
- self.views[view] = nil
+ drop[view] = true
end
end
end
@@ -224,6 +230,7 @@ function mt:_eraseAlias()
end
end
end
+ return drop
end
---@param tp string
@@ -273,17 +280,16 @@ function mt:view(default, uri)
return 'any'
end
- if not next(self.views) then
- return default or 'unknown'
- end
-
- if self.cachedView then
- return self.cachedView
+ local drop
+ if self._hasClass then
+ drop = self:_eraseAlias(uri or self.uri)
end
local array = {}
for view in pairs(self.views) do
- array[#array+1] = view
+ if not drop or not drop[view] then
+ array[#array+1] = view
+ end
end
table.sort(array, function (a, b)
@@ -298,22 +304,29 @@ function mt:view(default, uri)
local max = #array
local limit = config.get(uri or self.uri, 'Lua.hover.enumsLimit')
- if max > limit then
- local view = string.format('%s...(+%d)'
- , table.concat(array, '|', 1, limit)
- , max - limit
- )
-
- self.cachedView = view
-
- return view
+ local view
+ if #array == 0 then
+ view = default or 'unknown'
else
- local view = table.concat(array, '|')
-
- self.cachedView = view
+ if max > limit then
+ view = string.format('%s...(+%d)'
+ , table.concat(array, '|', 1, limit)
+ , max - limit
+ )
+ else
+ view = table.concat(array, '|')
+ end
+ end
- return view
+ if self.node:isOptional() then
+ if max > 1 then
+ view = '(' .. view .. ')?'
+ else
+ view = view .. '?'
+ end
end
+
+ return view
end
function mt:eachView()
@@ -324,10 +337,10 @@ end
---@param other vm.infer
---@return vm.infer
function mt:merge(other)
- if self == m.NULL then
+ if self == vm.NULL then
return other
end
- if other == m.NULL then
+ if other == vm.NULL then
return self
end
@@ -390,8 +403,6 @@ end
---@param source parser.object
---@return string?
-function m.viewObject(source)
+function vm.viewObject(source)
return viewNodeSwitch(source.type, source, {})
end
-
-return m
diff --git a/script/vm/init.lua b/script/vm/init.lua
index 0058c698..f5003c11 100644
--- a/script/vm/init.lua
+++ b/script/vm/init.lua
@@ -1,4 +1,7 @@
local vm = require 'vm.vm'
+
+---@alias vm.object parser.object | vm.global | vm.generic
+
require 'vm.compiler'
require 'vm.value'
require 'vm.node'
@@ -8,5 +11,10 @@ require 'vm.field'
require 'vm.doc'
require 'vm.type'
require 'vm.library'
-require 'vm.manager'
+require 'vm.runner'
+require 'vm.infer'
+require 'vm.generic'
+require 'vm.sign'
+require 'vm.local-id'
+require 'vm.global'
return vm
diff --git a/script/vm/library.lua b/script/vm/library.lua
index 49f7adb0..e7bf4f42 100644
--- a/script/vm/library.lua
+++ b/script/vm/library.lua
@@ -13,24 +13,3 @@ function vm.getLibraryName(source)
end
return nil
end
-
-local globalLibraryNames = {
- 'arg', 'assert', 'error', 'collectgarbage', 'dofile', '_G', 'getfenv',
- 'getmetatable', 'ipairs', 'load', 'loadfile', 'loadstring',
- 'module', 'next', 'pairs', 'pcall', 'print', 'rawequal',
- 'rawget', 'rawlen', 'rawset', 'select', 'setfenv',
- 'setmetatable', 'tonumber', 'tostring', 'type', '_VERSION',
- 'warn', 'xpcall', 'require', 'unpack', 'bit32', 'coroutine',
- 'debug', 'io', 'math', 'os', 'package', 'string', 'table',
- 'utf8', 'newproxy',
-}
-local globalLibraryNamesMap
-function vm.isGlobalLibraryName(name)
- if not globalLibraryNamesMap then
- globalLibraryNamesMap = {}
- for _, v in ipairs(globalLibraryNames) do
- globalLibraryNamesMap[v] = true
- end
- end
- return globalLibraryNamesMap[name] or false
-end
diff --git a/script/vm/local-id.lua b/script/vm/local-id.lua
index 728de301..80c68769 100644
--- a/script/vm/local-id.lua
+++ b/script/vm/local-id.lua
@@ -1,13 +1,13 @@
local util = require 'utility'
local guide = require 'parser.guide'
+---@class vm
local vm = require 'vm.vm'
---@class parser.object
---@field _localID string
---@field _localIDs table<string, parser.object[]>
----@class vm.local-id
-local m = {}
+local compileLocalID, getLocal
local compileSwitch = util.switch()
: case 'local'
@@ -18,13 +18,13 @@ local compileSwitch = util.switch()
return
end
for _, ref in ipairs(source.ref) do
- m.compileLocalID(ref)
+ compileLocalID(ref)
end
end)
: case 'getlocal'
: call(function (source)
source._localID = ('%d'):format(source.node.start)
- m.compileLocalID(source.next)
+ compileLocalID(source.next)
end)
: case 'getfield'
: case 'setfield'
@@ -40,7 +40,7 @@ local compileSwitch = util.switch()
source._localID = parentID .. vm.ID_SPLITE .. key
source.field._localID = source._localID
if source.type == 'getfield' then
- m.compileLocalID(source.next)
+ compileLocalID(source.next)
end
end)
: case 'getmethod'
@@ -57,7 +57,7 @@ local compileSwitch = util.switch()
source._localID = parentID .. vm.ID_SPLITE .. key
source.method._localID = source._localID
if source.type == 'getmethod' then
- m.compileLocalID(source.next)
+ compileLocalID(source.next)
end
end)
: case 'getindex'
@@ -74,7 +74,7 @@ local compileSwitch = util.switch()
source._localID = parentID .. vm.ID_SPLITE .. key
source.index._localID = source._localID
if source.type == 'setindex' then
- m.compileLocalID(source.next)
+ compileLocalID(source.next)
end
end)
@@ -82,7 +82,7 @@ local leftSwitch = util.switch()
: case 'field'
: case 'method'
: call(function (source)
- return m.getLocal(source.parent)
+ return getLocal(source.parent)
end)
: case 'getfield'
: case 'setfield'
@@ -91,24 +91,36 @@ local leftSwitch = util.switch()
: case 'getindex'
: case 'setindex'
: call(function (source)
- return m.getLocal(source.node)
+ return getLocal(source.node)
end)
: case 'getlocal'
: call(function (source)
return source.node
end)
: case 'local'
+ : case 'self'
: call(function (source)
return source
end)
---@param source parser.object
---@return parser.object?
-function m.getLocal(source)
+function getLocal(source)
return leftSwitch(source.type, source)
end
-function m.compileLocalID(source)
+---@param id string
+---@param source parser.object
+function vm.insertLocalID(id, source)
+ local root = guide.getRoot(source)
+ if not root._localIDs then
+ root._localIDs = util.multiTable(2)
+ end
+ local sources = root._localIDs[id]
+ sources[#sources+1] = source
+end
+
+function compileLocalID(source)
if not source then
return
end
@@ -117,37 +129,33 @@ function m.compileLocalID(source)
return
end
compileSwitch(source.type, source)
- if not source._localID then
+ local id = source._localID
+ if not id then
return
end
- local root = guide.getRoot(source)
- if not root._localIDs then
- root._localIDs = util.multiTable(2)
- end
- local sources = root._localIDs[source._localID]
- sources[#sources+1] = source
+ vm.insertLocalID(id, source)
end
---@param source parser.object
----@return string|boolean
-function m.getID(source)
+---@return string?
+function vm.getLocalID(source)
if source._localID ~= nil then
return source._localID
end
source._localID = false
- local loc = m.getLocal(source)
+ local loc = getLocal(source)
if not loc then
return source._localID
end
- m.compileLocalID(loc)
+ compileLocalID(loc)
return source._localID
end
---@param source parser.object
---@param key? string
---@return parser.object[]?
-function m.getSources(source, key)
- local id = m.getID(source)
+function vm.getLocalSources(source, key)
+ local id = vm.getLocalID(source)
if not id then
return nil
end
@@ -166,8 +174,8 @@ end
---@param source parser.object
---@return parser.object[]
-function m.getFields(source)
- local id = m.getID(source)
+function vm.getLocalFields(source)
+ local id = vm.getLocalID(source)
if not id then
return nil
end
@@ -195,5 +203,3 @@ function m.getFields(source)
end
return fields
end
-
-return m
diff --git a/script/vm/local-manager.lua b/script/vm/local-manager.lua
deleted file mode 100644
index 51bafb24..00000000
--- a/script/vm/local-manager.lua
+++ /dev/null
@@ -1,40 +0,0 @@
-local util = require 'utility'
-local guide = require 'parser.guide'
-
----@class vm.local-node
-local m = {}
----@type table<uri, parser.object[]>
-m.locals = util.multiTable(2)
----@type table<parser.object, table<parser.object, boolean>>
-m.localSubs = util.multiTable(2, function ()
- return setmetatable({}, util.MODE_K)
-end)
----@type table<parser.object, boolean>
-m.allLocals = {}
-
----@param source parser.object
-function m.declareLocal(source)
- if m.allLocals[source] then
- return
- end
- m.allLocals[source] = true
- local uri = guide.getUri(source)
- local locals = m.locals[uri]
- locals[#locals+1] = source
-end
-
----@param uri uri
-function m.dropUri(uri)
- local locals = m.locals[uri]
- m.locals[uri] = nil
- for _, loc in ipairs(locals) do
- m.allLocals[loc] = nil
- local localSubs = m.localSubs[loc]
- m.localSubs[loc] = nil
- for source in pairs(localSubs) do
- source._node = nil
- end
- end
-end
-
-return m
diff --git a/script/vm/manager.lua b/script/vm/manager.lua
deleted file mode 100644
index 58255fca..00000000
--- a/script/vm/manager.lua
+++ /dev/null
@@ -1,26 +0,0 @@
-
-local files = require 'files'
-local globalManager = require 'vm.global-manager'
-local localManager = require 'vm.local-manager'
-
----@alias vm.object parser.object | vm.global | vm.generic
-
----@class vm.state
-local m = {}
-
-files.watch(function (ev, uri)
- if ev == 'update' then
- globalManager.dropUri(uri)
- localManager.dropUri(uri)
- local state = files.getState(uri)
- if state then
- globalManager.compileAst(state.ast)
- end
- end
- if ev == 'remove' then
- globalManager.dropUri(uri)
- localManager.dropUri(uri)
- end
-end)
-
-return m
diff --git a/script/vm/node.lua b/script/vm/node.lua
index 6906da7e..e76542aa 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -1,5 +1,4 @@
local files = require 'files'
-local localMgr = require 'vm.local-manager'
---@class vm
local vm = require 'vm.vm'
local ws = require 'workspace.workspace'
@@ -8,15 +7,14 @@ local ws = require 'workspace.workspace'
vm.nodeCache = {}
---@class vm.node
+---@field [integer] vm.object
local mt = {}
mt.__index = mt
+mt.id = 0
mt.type = 'vm.node'
mt.optional = nil
mt.lastInfer = nil
mt.data = nil
----@type vm.node[]
-mt._childs = nil
-mt._locked = false
---@param node vm.node | vm.object
function mt:merge(node)
@@ -30,20 +28,10 @@ function mt:merge(node)
if node:isOptional() then
self.optional = true
end
- if node._locked then
- if not self._childs then
- self._childs = {}
- end
- if not self._childs[node] then
- self._childs[#self._childs+1] = node
- self._childs[node] = true
- end
- else
- for _, obj in ipairs(node) do
- if not self[obj] then
- self[obj] = true
- self[#self+1] = obj
- end
+ for _, obj in ipairs(node) do
+ if not self[obj] then
+ self[obj] = true
+ self[#self+1] = obj
end
end
else
@@ -54,84 +42,25 @@ function mt:merge(node)
end
end
-function mt:_each(mark, callback)
- if mark[self] then
- return
- end
- mark[self] = true
- for i = 1, #self do
- callback(self[i])
- end
- local childs = self._childs
- if not childs then
- return
- end
- for i = 1, #childs do
- local child = childs[i]
- if not child:isLocked() then
- child:_each(mark, callback)
- end
- end
-end
-
-function mt:_expand()
- local childs = self._childs
- if not childs then
- return
- end
- self._childs = nil
-
- local mark = {}
- mark[self] = true
-
- local function insert(obj)
- if not self[obj] then
- self[obj] = true
- self[#self+1] = obj
- end
- end
-
- for i = 1, #childs do
- local child = childs[i]
- if child:isLocked() then
- if not self._childs then
- self._childs = {}
- end
- if not self._childs[child] then
- self._childs[#self._childs+1] = child
- self._childs[child] = true
- end
- else
- child:_each(mark, insert)
- end
- end
-end
-
---@return boolean
function mt:isEmpty()
- self:_expand()
return #self == 0
end
+function mt:clear()
+ self.optional = nil
+ for i, c in ipairs(self) do
+ self[i] = nil
+ self[c] = nil
+ end
+end
+
---@param n integer
---@return vm.object?
function mt:get(n)
- self:_expand()
return self[n]
end
-function mt:lock()
- self._locked = true
-end
-
-function mt:unlock()
- self._locked = false
-end
-
-function mt:isLocked()
- return self._locked == true
-end
-
function mt:setData(k, v)
if not self.data then
self.data = {}
@@ -147,49 +76,143 @@ function mt:getData(k)
end
function mt:addOptional()
- if self:isOptional() then
- return self
- end
self.optional = true
end
function mt:removeOptional()
- if not self:isOptional() then
- return self
- end
- self:_expand()
- for i = #self, 1, -1 do
- local n = self[i]
- if n.type == 'nil'
- or (n.type == 'boolean' and n[1] == false)
- or (n.type == 'doc.type.boolean' and n[1] == false) then
- self[i] = self[#self]
- self[#self] = nil
- end
- end
+ self:remove 'nil'
end
---@return boolean
function mt:isOptional()
- if self.optional ~= nil then
- return self.optional
+ return self.optional == true
+end
+
+---@return boolean
+function mt:hasFalsy()
+ if self.optional then
+ return true
end
- self:_expand()
for _, c in ipairs(self) do
if c.type == 'nil'
+ or (c.type == 'global' and c.cate == 'type' and c.name == 'nil')
+ or (c.type == 'global' and c.cate == 'type' and c.name == 'false')
or (c.type == 'boolean' and c[1] == false)
or (c.type == 'doc.type.boolean' and c[1] == false) then
- self.optional = true
return true
end
end
- self.optional = false
return false
end
+---@return boolean
+function mt:isNullable()
+ if self.optional then
+ return true
+ end
+ if #self == 0 then
+ return true
+ end
+ for _, c in ipairs(self) do
+ if c.type == 'nil'
+ or (c.type == 'global' and c.cate == 'type' and c.name == 'nil')
+ or (c.type == 'global' and c.cate == 'type' and c.name == 'any') then
+ return true
+ end
+ end
+ return false
+end
+
+---@return vm.node
+function mt:setTruthy()
+ if self.optional == true then
+ self.optional = nil
+ end
+ local hasBoolean
+ for index = #self, 1, -1 do
+ local c = self[index]
+ if c.type == 'nil'
+ or (c.type == 'global' and c.cate == 'type' and c.name == 'nil')
+ or (c.type == 'global' and c.cate == 'type' and c.name == 'false')
+ or (c.type == 'boolean' and c[1] == false)
+ or (c.type == 'doc.type.boolean' and c[1] == false) then
+ table.remove(self, index)
+ self[c] = nil
+ goto CONTINUE
+ end
+ if (c.type == 'global' and c.cate == 'type' and c.name == 'boolean')
+ or (c.type == 'boolean' or c.type == 'doc.type.boolean') then
+ hasBoolean = true
+ table.remove(self, index)
+ self[c] = nil
+ goto CONTINUE
+ end
+ ::CONTINUE::
+ end
+ if hasBoolean then
+ self[#self+1] = vm.declareGlobal('type', 'true')
+ end
+end
+
+---@return vm.node
+function mt:setFalsy()
+ if self.optional == false then
+ self.optional = nil
+ end
+ local hasBoolean
+ for index = #self, 1, -1 do
+ local c = self[index]
+ if c.type == 'nil'
+ or (c.type == 'global' and c.cate == 'type' and c.name == 'nil')
+ or (c.type == 'global' and c.cate == 'type' and c.name == 'false')
+ or (c.type == 'boolean' and c[1] == true)
+ or (c.type == 'doc.type.boolean' and c[1] == true) then
+ goto CONTINUE
+ end
+ if (c.type == 'global' and c.cate == 'type' and c.name == 'boolean')
+ or (c.type == 'boolean' or c.type == 'doc.type.boolean') then
+ hasBoolean = true
+ table.remove(self, index)
+ self[c] = nil
+ end
+ ::CONTINUE::
+ end
+ if hasBoolean then
+ self[#self+1] = vm.declareGlobal('type', 'false')
+ end
+end
+
+---@param name string
+function mt:remove(name)
+ if name == 'nil' and self.optional == true then
+ self.optional = nil
+ end
+ for index = #self, 1, -1 do
+ local c = self[index]
+ if (c.type == 'global' and c.cate == 'type' and c.name == name)
+ or (c.type == name)
+ or (c.type == 'doc.type.integer' and (name == 'number' or name == 'integer'))
+ or (c.type == 'doc.type.boolean' and name == 'boolean')
+ or (c.type == 'doc.type.table' and name == 'table')
+ or (c.type == 'doc.type.array' and name == 'table')
+ or (c.type == 'doc.type.function' and name == 'function') then
+ table.remove(self, index)
+ self[c] = nil
+ end
+ end
+end
+
+---@param node vm.node
+function mt:removeNode(node)
+ for _, c in ipairs(node) do
+ if c.type == 'global' and c.cate == 'type' then
+ self:remove(c.name)
+ end
+ end
+end
+
---@return fun():vm.object
function mt:eachObject()
- self:_expand()
local i = 0
return function ()
i = i + 1
@@ -197,12 +220,21 @@ function mt:eachObject()
end
end
----@param source parser.object | vm.generic
+---@return vm.node
+function mt:copy()
+ return vm.createNode(self)
+end
+
+---@param source vm.object
---@param node vm.node | vm.object
---@param cover? boolean
function vm.setNode(source, node, cover)
if not node then
- error('Can not set nil node')
+ if TEST then
+ error('Can not set nil node')
+ else
+ log.error('Can not set nil node')
+ end
end
if source.type == 'global' then
error('Can not set node to global')
@@ -216,13 +248,14 @@ function vm.setNode(source, node, cover)
me:merge(node)
else
if node.type == 'vm.node' then
- vm.nodeCache[source] = node
+ vm.nodeCache[source] = node:copy()
else
vm.nodeCache[source] = vm.createNode(node)
end
end
end
+---@param source vm.object
---@return vm.node?
function vm.getNode(source)
return vm.nodeCache[source]
@@ -256,11 +289,16 @@ function vm.clearNodeCache()
vm.nodeCache = {}
end
+local ID = 0
+
---@param a? vm.node | vm.object
---@param b? vm.node | vm.object
---@return vm.node
function vm.createNode(a, b)
- local node = setmetatable({}, mt)
+ ID = ID + 1
+ local node = setmetatable({
+ id = ID,
+ }, mt)
if a then
node:merge(a)
end
diff --git a/script/vm/ref.lua b/script/vm/ref.lua
index 65e8fdab..545c294a 100644
--- a/script/vm/ref.lua
+++ b/script/vm/ref.lua
@@ -2,8 +2,6 @@
local vm = require 'vm.vm'
local util = require 'utility'
local guide = require 'parser.guide'
-local localID = require 'vm.local-id'
-local globalMgr = require 'vm.global-manager'
local files = require 'files'
local await = require 'await'
local progress = require 'progress'
@@ -242,7 +240,7 @@ end
---@param source parser.object
---@param pushResult fun(src: parser.object)
local function searchByLocalID(source, pushResult)
- local idSources = localID.getSources(source)
+ local idSources = vm.getLocalSources(source)
if not idSources then
return
end
@@ -291,7 +289,7 @@ end
---@async
---@param source parser.object
----@param fileNotify fun(uri: uri): boolean
+---@param fileNotify? fun(uri: uri): boolean
function vm.getRefs(source, fileNotify)
local results = {}
local mark = {}
diff --git a/script/vm/runner.lua b/script/vm/runner.lua
new file mode 100644
index 00000000..9fe0f172
--- /dev/null
+++ b/script/vm/runner.lua
@@ -0,0 +1,444 @@
+---@class vm
+local vm = require 'vm.vm'
+local guide = require 'parser.guide'
+
+---@class vm.runner
+---@field loc parser.object
+---@field mainBlock parser.object
+---@field blocks table<parser.object, true>
+---@field steps vm.runner.step[]
+local mt = {}
+mt.__index = mt
+mt.index = 1
+
+---@class parser.object
+---@field _casts parser.object[]
+
+---@class vm.runner.step
+---@field type 'truthy' | 'falsy' | 'as' | 'add' | 'remove' | 'object' | 'save' | 'push' | 'merge' | 'cast'
+---@field pos integer
+---@field order? integer
+---@field node? vm.node
+---@field object? parser.object
+---@field name? string
+---@field cast? parser.object
+---@field tag? string
+---@field copy? boolean
+---@field new? boolean
+---@field ref1? vm.runner.step
+---@field ref2? vm.runner.step
+
+---@param filter parser.object
+---@param outStep vm.runner.step
+---@param blockStep vm.runner.step
+function mt:_compileNarrowByFilter(filter, outStep, blockStep)
+ if not filter then
+ return
+ end
+ if filter.type == 'paren' then
+ if filter.exp then
+ self:_compileNarrowByFilter(filter.exp, outStep, blockStep)
+ end
+ return
+ end
+ if filter.type == 'unary' then
+ if not filter.op
+ or not filter[1] then
+ return
+ end
+ if filter.op.type == 'not' then
+ local exp = filter[1]
+ if exp.type == 'getlocal' and exp.node == self.loc then
+ self.steps[#self.steps+1] = {
+ type = 'falsy',
+ pos = filter.finish,
+ new = true,
+ }
+ self.steps[#self.steps+1] = {
+ type = 'truthy',
+ pos = filter.finish,
+ ref1 = outStep,
+ }
+ end
+ end
+ elseif filter.type == 'binary' then
+ if not filter.op
+ or not filter[1]
+ or not filter[2] then
+ return
+ end
+ if filter.op.type == 'and' then
+ local dummyStep = {
+ type = 'save',
+ copy = true,
+ ref1 = outStep,
+ pos = filter.start - 1,
+ }
+ self.steps[#self.steps+1] = dummyStep
+ self:_compileNarrowByFilter(filter[1], dummyStep, blockStep)
+ self:_compileNarrowByFilter(filter[2], dummyStep, blockStep)
+ end
+ if filter.op.type == 'or' then
+ self:_compileNarrowByFilter(filter[1], outStep, blockStep)
+ local dummyStep = {
+ type = 'push',
+ copy = true,
+ ref1 = outStep,
+ pos = filter.op.finish,
+ }
+ self.steps[#self.steps+1] = dummyStep
+ self:_compileNarrowByFilter(filter[2], outStep, dummyStep)
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ tag = 'or reset',
+ ref1 = blockStep,
+ pos = filter.finish,
+ }
+ end
+ if filter.op.type == '=='
+ or filter.op.type == '~=' then
+ local loc, exp
+ for i = 1, 2 do
+ loc = filter[i]
+ if loc.type == 'getlocal' and loc.node == self.loc then
+ exp = filter[i % 2 + 1]
+ break
+ end
+ end
+ if not loc or not exp then
+ return
+ end
+ if guide.isLiteral(exp) then
+ if filter.op.type == '==' then
+ self.steps[#self.steps+1] = {
+ type = 'remove',
+ name = exp.type,
+ pos = filter.finish,
+ ref1 = outStep,
+ }
+ self.steps[#self.steps+1] = {
+ type = 'as',
+ name = exp.type,
+ pos = filter.finish,
+ new = true,
+ }
+ end
+ if filter.op.type == '~=' then
+ self.steps[#self.steps+1] = {
+ type = 'as',
+ name = exp.type,
+ pos = filter.finish,
+ ref1 = outStep,
+ }
+ self.steps[#self.steps+1] = {
+ type = 'remove',
+ name = exp.type,
+ pos = filter.finish,
+ new = true,
+ }
+ end
+ end
+ end
+ else
+ if filter.type == 'getlocal' and filter.node == self.loc then
+ self.steps[#self.steps+1] = {
+ type = 'truthy',
+ pos = filter.finish,
+ new = true,
+ }
+ self.steps[#self.steps+1] = {
+ type = 'falsy',
+ pos = filter.finish,
+ ref1 = outStep,
+ }
+ end
+ end
+end
+
+---@param block parser.object
+function mt:_compileBlock(block)
+ if self.blocks[block] then
+ return
+ end
+ self.blocks[block] = true
+ if block == self.mainBlock then
+ return
+ end
+
+ local parentBlock = guide.getParentBlock(block)
+ self:_compileBlock(parentBlock)
+
+ if block.type == 'if' then
+ ---@type vm.runner.step[]
+ local finals = {}
+ for i, childBlock in ipairs(block) do
+ local blockStep = {
+ type = 'save',
+ tag = 'block',
+ copy = true,
+ pos = childBlock.start,
+ }
+ local outStep = {
+ type = 'save',
+ tag = 'out',
+ copy = true,
+ pos = childBlock.start,
+ }
+ self.steps[#self.steps+1] = blockStep
+ self.steps[#self.steps+1] = outStep
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ ref1 = blockStep,
+ pos = childBlock.start,
+ }
+ self:_compileNarrowByFilter(childBlock.filter, outStep, blockStep)
+ if not childBlock.hasReturn
+ and not childBlock.hasGoTo
+ and not childBlock.hasBreak then
+ local finalStep = {
+ type = 'save',
+ pos = childBlock.finish,
+ tag = 'final #' .. i,
+ }
+ finals[#finals+1] = finalStep
+ self.steps[#self.steps+1] = finalStep
+ end
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ tag = 'reset child',
+ ref1 = outStep,
+ pos = childBlock.finish,
+ }
+ end
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ tag = 'reset if',
+ pos = block.finish,
+ copy = true,
+ }
+ for _, final in ipairs(finals) do
+ self.steps[#self.steps+1] = {
+ type = 'merge',
+ ref2 = final,
+ pos = block.finish,
+ }
+ end
+ end
+
+ if block.type == 'function'
+ or block.type == 'while'
+ or block.type == 'loop'
+ or block.type == 'in'
+ or block.type == 'repeat'
+ or block.type == 'for' then
+ local savePoint = {
+ type = 'save',
+ copy = true,
+ pos = block.start,
+ }
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ copy = true,
+ pos = block.start,
+ }
+ self.steps[#self.steps+1] = savePoint
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ pos = block.finish,
+ ref1 = savePoint,
+ }
+ end
+end
+
+---@return parser.object[]
+function mt:_getCasts()
+ local root = guide.getRoot(self.loc)
+ if not root._casts then
+ root._casts = {}
+ local docs = root.docs
+ for _, doc in ipairs(docs) do
+ if doc.type == 'doc.cast' and doc.loc then
+ root._casts[#root._casts+1] = doc
+ end
+ end
+ end
+ return root._casts
+end
+
+function mt:_preCompile()
+ local startPos = self.loc.start
+ local finishPos = 0
+
+ for _, ref in ipairs(self.loc.ref) do
+ self.steps[#self.steps+1] = {
+ type = 'object',
+ object = ref,
+ pos = ref.range or ref.start,
+ }
+ if ref.start > finishPos then
+ finishPos = ref.start
+ end
+ local block = guide.getParentBlock(ref)
+ self:_compileBlock(block)
+ end
+
+ for i, step in ipairs(self.steps) do
+ if step.type ~= 'object' then
+ step.order = i
+ end
+ end
+
+ local casts = self:_getCasts()
+ for _, cast in ipairs(casts) do
+ if cast.loc[1] == self.loc[1]
+ and cast.start > startPos
+ and cast.finish < finishPos
+ and guide.getLocal(self.loc, self.loc[1], cast.start) == self.loc then
+ self.steps[#self.steps+1] = {
+ type = 'cast',
+ cast = cast,
+ pos = cast.start,
+ }
+ end
+ end
+
+ table.sort(self.steps, function (a, b)
+ if a.pos == b.pos then
+ return (a.order or 0) < (b.order or 0)
+ else
+ return a.pos < b.pos
+ end
+ end)
+end
+
+---@param loc parser.object
+---@param node vm.node
+---@return vm.node
+local function checkAssert(loc, node)
+ local parent = loc.parent
+ if parent.type == 'binary' then
+ if parent.op and (parent.op.type == '~=' or parent.op.type == '==') then
+ local exp
+ for i = 1, 2 do
+ if parent[i] == loc then
+ exp = parent[i % 2 + 1]
+ end
+ end
+ if exp and guide.isLiteral(exp) then
+ local callargs = parent.parent
+ if callargs.type == 'callargs'
+ and callargs.parent.node.special == 'assert'
+ and callargs[1] == parent then
+ if parent.op.type == '~=' then
+ node:remove(exp.type)
+ end
+ if parent.op.type == '==' then
+ node = vm.compileNode(exp)
+ end
+ end
+ end
+ end
+ end
+ if parent.type == 'callargs'
+ and parent.parent.node.special == 'assert'
+ and parent[1] == loc then
+ node:setTruthy()
+ end
+ return node
+end
+
+---@param callback fun(src: parser.object, node: vm.node)
+function mt:launch(callback)
+ local topNode = vm.getNode(self.loc):copy()
+ for _, step in ipairs(self.steps) do
+ local node = step.ref1 and step.ref1.node or topNode
+ if step.type == 'truthy' then
+ if step.new then
+ node = node:copy()
+ topNode = node
+ end
+ node:setTruthy()
+ elseif step.type == 'falsy' then
+ if step.new then
+ node = node:copy()
+ topNode = node
+ end
+ node:setFalsy()
+ elseif step.type == 'as' then
+ if step.new then
+ topNode = vm.createNode(vm.getGlobal('type', step.name))
+ else
+ node:clear()
+ node:merge(vm.getGlobal('type', step.name))
+ end
+ elseif step.type == 'add' then
+ if step.new then
+ node = node:copy()
+ topNode = node
+ end
+ node:merge(vm.getGlobal('type', step.name))
+ elseif step.type == 'remove' then
+ if step.new then
+ node = node:copy()
+ topNode = node
+ end
+ node:remove(step.name)
+ elseif step.type == 'object' then
+ topNode = callback(step.object, node) or node
+ if step.object.type == 'getlocal' then
+ topNode = checkAssert(step.object, node)
+ end
+ elseif step.type == 'save' then
+ if step.copy then
+ node = node:copy()
+ end
+ step.node = node
+ elseif step.type == 'push' then
+ if step.copy then
+ node = node:copy()
+ end
+ topNode = node
+ elseif step.type == 'merge' then
+ node:merge(step.ref2.node)
+ elseif step.type == 'cast' then
+ topNode = node:copy()
+ for _, cast in ipairs(step.cast.casts) do
+ if cast.mode == '+' then
+ if cast.optional then
+ topNode:addOptional()
+ end
+ if cast.extends then
+ topNode:merge(vm.compileNode(cast.extends))
+ end
+ elseif cast.mode == '-' then
+ if cast.optional then
+ topNode:removeOptional()
+ end
+ if cast.extends then
+ topNode:removeNode(vm.compileNode(cast.extends))
+ end
+ else
+ if cast.extends then
+ topNode:clear()
+ topNode:merge(vm.compileNode(cast.extends))
+ end
+ end
+ end
+ end
+ end
+end
+
+---@param loc parser.object
+---@return vm.runner
+function vm.createRunner(loc)
+ local self = setmetatable({
+ loc = loc,
+ mainBlock = guide.getParentBlock(loc),
+ blocks = {},
+ steps = {},
+ }, mt)
+
+ self:_preCompile()
+
+ return self
+end
diff --git a/script/vm/sign.lua b/script/vm/sign.lua
index 2d45a5a7..fe112bc2 100644
--- a/script/vm/sign.lua
+++ b/script/vm/sign.lua
@@ -1,6 +1,6 @@
local guide = require 'parser.guide'
+---@class vm
local vm = require 'vm.vm'
-local infer = require 'vm.infer'
---@class vm.sign
---@field parent parser.object
@@ -16,12 +16,12 @@ end
---@param uri uri
---@param args parser.object
+---@param removeGeneric true?
---@return table<string, vm.node>
-function mt:resolve(uri, args)
+function mt:resolve(uri, args, removeGeneric)
if not args then
return nil
end
- local globalMgr = require 'vm.global-manager'
local resolved = {}
---@param object parser.object
@@ -33,7 +33,7 @@ function mt:resolve(uri, args)
-- 'number' -> `T`
for n in node:eachObject() do
if n.type == 'string' then
- local type = globalMgr.declareGlobal('type', n[1], guide.getUri(n))
+ local type = vm.declareGlobal('type', n[1], guide.getUri(n))
resolved[key] = vm.createNode(type, resolved[key])
end
end
@@ -48,6 +48,19 @@ function mt:resolve(uri, args)
-- number[] -> T[]
resolve(object.node, vm.compileNode(n.node))
end
+ if n.type == 'doc.type.table' then
+ -- { [integer]: number } -> T[]
+ local tvalueNode = vm.getTableValue(uri, node, 'integer')
+ if tvalueNode then
+ resolve(object.node, tvalueNode)
+ end
+ end
+ if n.type == 'global' and n.cate == 'type' then
+ -- ---@field [integer]: number -> T[]
+ vm.getClassFields(uri, n, vm.declareGlobal('type', 'integer'), false, function (field)
+ resolve(object.node, vm.compileNode(field.extends))
+ end)
+ end
end
end
if object.type == 'doc.type.table' then
@@ -98,7 +111,7 @@ function mt:resolve(uri, args)
goto CONTINUE
end
end
- local view = infer.viewObject(obj)
+ local view = vm.viewObject(obj)
if view then
knownTypes[view] = true
end
@@ -114,10 +127,10 @@ function mt:resolve(uri, args)
local function buildArgNode(argNode, knownTypes)
local newArgNode = vm.createNode()
for n in argNode:eachObject() do
- if argNode:isOptional() and vm.isFalsy(n) then
+ if argNode:hasFalsy() then
goto CONTINUE
end
- local view = infer.viewObject(n)
+ local view = vm.viewObject(n)
if knownTypes[view] then
goto CONTINUE
end
@@ -156,7 +169,7 @@ function mt:resolve(uri, args)
end
---@return vm.sign
-return function ()
+function vm.createSign()
local genericMgr = setmetatable({
signList = {},
}, mt)
diff --git a/script/vm/type.lua b/script/vm/type.lua
index fa02d19e..c3264993 100644
--- a/script/vm/type.lua
+++ b/script/vm/type.lua
@@ -1,4 +1,3 @@
-local globalMgr = require 'vm.global-manager'
---@class vm
local vm = require 'vm.vm'
@@ -9,10 +8,10 @@ local vm = require 'vm.vm'
---@return boolean
function vm.isSubType(uri, child, parent, mark)
if type(parent) == 'string' then
- parent = vm.createNode(globalMgr.getGlobal('type', parent))
+ parent = vm.createNode(vm.getGlobal('type', parent))
end
if type(child) == 'string' then
- child = vm.createNode(globalMgr.getGlobal('type', child))
+ child = vm.createNode(vm.getGlobal('type', child))
end
if not child or not parent then
@@ -134,7 +133,7 @@ function vm.getTableKey(uri, tnode, vnode)
end
end
if tn.type == 'doc.type.array' then
- result:merge(globalMgr.getGlobal('type', 'integer'))
+ result:merge(vm.declareGlobal('type', 'integer'))
end
if tn.type == 'table' then
for _, field in ipairs(tn) do
@@ -144,10 +143,10 @@ function vm.getTableKey(uri, tnode, vnode)
end
end
if field.type == 'tablefield' then
- result:merge(globalMgr.getGlobal('type', 'string'))
+ result:merge(vm.declareGlobal('type', 'string'))
end
if field.type == 'tableexp' then
- result:merge(globalMgr.getGlobal('type', 'integer'))
+ result:merge(vm.declareGlobal('type', 'integer'))
end
end
end
diff --git a/script/vm/value.lua b/script/vm/value.lua
index a784be2a..d29ca9d0 100644
--- a/script/vm/value.lua
+++ b/script/vm/value.lua
@@ -17,7 +17,16 @@ function vm.test(source)
hasTrue = true
end
if n[1] == false then
- hasTrue = false
+ hasFalse = true
+ end
+ end
+ if n.type == 'global' and n.cate == 'type' then
+ if n.name == 'true' then
+ hasTrue = true
+ end
+ if n.name == 'false'
+ or n.name == 'nil' then
+ hasFalse = true
end
end
if n.type == 'nil' then
@@ -41,28 +50,9 @@ function vm.test(source)
end
end
----@param source parser.object
----@return boolean
-function vm.isFalsy(source)
- if source.type == 'nil' then
- return true
- end
- if source.type == 'boolean'
- or source.type == 'doc.type.boolean' then
- return source[1] == false
- end
- return false
-end
-
---@param v vm.object
---@return string?
local function getUnique(v)
- if v.type == 'local' then
- return ('loc:%s@%d'):format(guide.getUri(v), v.start)
- end
- if v.type == 'global' then
- return ('%s:%s'):format(v.cate, v.name)
- end
if v.type == 'boolean' then
if v[1] == nil then
return false
diff --git a/script/vm/vm.lua b/script/vm/vm.lua
index 3c1762bf..8117d311 100644
--- a/script/vm/vm.lua
+++ b/script/vm/vm.lua
@@ -23,6 +23,7 @@ function m.getSpecial(source)
return source.special
end
+---@return string?
function m.getKeyName(source)
if not source then
return nil
diff --git a/script/workspace/loading.lua b/script/workspace/loading.lua
index f40c08c6..66e0a3aa 100644
--- a/script/workspace/loading.lua
+++ b/script/workspace/loading.lua
@@ -65,7 +65,7 @@ function mt:checkMaxPreload(uri)
end
---@param uri uri
----@param libraryUri boolean
+---@param libraryUri? uri
---@async
function mt:loadFile(uri, libraryUri)
if files.isLua(uri) then
diff --git a/script/workspace/workspace.lua b/script/workspace/workspace.lua
index 91923bb8..33f8784d 100644
--- a/script/workspace/workspace.lua
+++ b/script/workspace/workspace.lua
@@ -68,9 +68,10 @@ local globInteferFace = {
type = function (path)
local result
pcall(function ()
- if fs.is_directory(fs.path(path)) then
+ local status = fs.symlink_status(path):type()
+ if status == 'directory' then
result = 'directory'
- else
+ elseif status == 'regular' then
result = 'file'
end
end)
@@ -78,7 +79,7 @@ local globInteferFace = {
end,
list = function (path)
local fullPath = fs.path(path)
- if not fs.exists(fullPath) then
+ if fs.symlink_status(fullPath):type() ~= 'directory' then
return nil
end
local paths = {}
@@ -332,6 +333,8 @@ function m.findUrisByFilePath(path)
return results
end
+---@param path string
+---@return string
function m.normalize(path)
if not path then
return nil