summaryrefslogtreecommitdiff
path: root/script
diff options
context:
space:
mode:
authorCppCXY <812125110@qq.com>2022-08-11 19:36:36 +0800
committerCppCXY <812125110@qq.com>2022-08-11 19:36:36 +0800
commitff9103ae4001d8e520171b99cd192997fc689bc9 (patch)
tree04c0b685e81aac48210604dc12d24b91862a36d9 /script
parent40f191a85ea21bb64c427f9dab4bc597e2a0ea1b (diff)
parent82bcfef9037c26681993c94b2f92b68d335de3c6 (diff)
downloadlua-language-server-ff9103ae4001d8e520171b99cd192997fc689bc9.zip
Merge branch 'master' of github.com:CppCXY/lua-language-server
Diffstat (limited to 'script')
-rw-r--r--script/SDBMHash.lua60
-rw-r--r--script/await.lua3
-rw-r--r--script/cli/check.lua4
-rw-r--r--script/client.lua26
-rw-r--r--script/config/config.lua285
-rw-r--r--script/config/loader.lua9
-rw-r--r--script/config/template.lua388
-rw-r--r--script/core/code-action.lua88
-rw-r--r--script/core/collector.lua188
-rw-r--r--script/core/color.lua79
-rw-r--r--script/core/command/autoRequire.lua7
-rw-r--r--script/core/command/removeSpace.lua13
-rw-r--r--script/core/command/solve.lua2
-rw-r--r--script/core/completion/completion.lua339
-rw-r--r--script/core/definition.lua11
-rw-r--r--script/core/diagnostics/ambiguity-1.lua4
-rw-r--r--script/core/diagnostics/assign-type-mismatch.lua117
-rw-r--r--script/core/diagnostics/cast-local-type.lua50
-rw-r--r--script/core/diagnostics/cast-type-mismatch.lua45
-rw-r--r--script/core/diagnostics/circle-doc-class.lua3
-rw-r--r--script/core/diagnostics/close-non-object.lua9
-rw-r--r--script/core/diagnostics/code-after-break.lua4
-rw-r--r--script/core/diagnostics/codestyle-check.lua2
-rw-r--r--script/core/diagnostics/count-down-loop.lua13
-rw-r--r--script/core/diagnostics/deprecated.lua2
-rw-r--r--script/core/diagnostics/different-requires.lua2
-rw-r--r--script/core/diagnostics/duplicate-doc-alias.lua18
-rw-r--r--script/core/diagnostics/duplicate-doc-field.lua11
-rw-r--r--script/core/diagnostics/duplicate-index.lua5
-rw-r--r--script/core/diagnostics/duplicate-set-field.lua12
-rw-r--r--script/core/diagnostics/empty-block.lua7
-rw-r--r--script/core/diagnostics/global-in-nil-env.lua68
-rw-r--r--script/core/diagnostics/init.lua94
-rw-r--r--script/core/diagnostics/lowercase-global.lua6
-rw-r--r--script/core/diagnostics/missing-parameter.lua57
-rw-r--r--script/core/diagnostics/missing-return-value.lua66
-rw-r--r--script/core/diagnostics/missing-return.lua86
-rw-r--r--script/core/diagnostics/need-check-nil.lua10
-rw-r--r--script/core/diagnostics/newfield-call.lua18
-rw-r--r--script/core/diagnostics/newline-call.lua13
-rw-r--r--script/core/diagnostics/no-unknown.lua29
-rw-r--r--script/core/diagnostics/not-yieldable.lua4
-rw-r--r--script/core/diagnostics/param-type-mismatch.lua72
-rw-r--r--script/core/diagnostics/redefined-local.lua5
-rw-r--r--script/core/diagnostics/redundant-parameter.lua73
-rw-r--r--script/core/diagnostics/redundant-return-value.lua73
-rw-r--r--script/core/diagnostics/return-type-mismatch.lua76
-rw-r--r--script/core/diagnostics/spell-check.lua34
-rw-r--r--script/core/diagnostics/trailing-space.lua20
-rw-r--r--script/core/diagnostics/type-check.lua3
-rw-r--r--script/core/diagnostics/unbalanced-assignments.lua22
-rw-r--r--script/core/diagnostics/undefined-doc-name.lua2
-rw-r--r--script/core/diagnostics/undefined-doc-param.lua42
-rw-r--r--script/core/diagnostics/undefined-env-child.lua32
-rw-r--r--script/core/diagnostics/undefined-field.lua8
-rw-r--r--script/core/diagnostics/undefined-global.lua10
-rw-r--r--script/core/diagnostics/unknown-cast-variable.lua32
-rw-r--r--script/core/diagnostics/unknown-diag-code.lua4
-rw-r--r--script/core/diagnostics/unknown-operator.lua36
-rw-r--r--script/core/diagnostics/unreachable-code.lua84
-rw-r--r--script/core/diagnostics/unused-function.lua5
-rw-r--r--script/core/diagnostics/unused-local.lua33
-rw-r--r--script/core/diagnostics/unused-vararg.lua3
-rw-r--r--script/core/find-source.lua2
-rw-r--r--script/core/folding.lua12
-rw-r--r--script/core/formatting.lua9
-rw-r--r--script/core/hint.lua110
-rw-r--r--script/core/hover/args.lua16
-rw-r--r--script/core/hover/description.lua218
-rw-r--r--script/core/hover/init.lua30
-rw-r--r--script/core/hover/label.lua19
-rw-r--r--script/core/hover/name.lua13
-rw-r--r--script/core/hover/return.lua53
-rw-r--r--script/core/hover/table.lua18
-rw-r--r--script/core/jump-source.lua62
-rw-r--r--script/core/look-backward.lua14
-rw-r--r--script/core/reference.lua8
-rw-r--r--script/core/rename.lua35
-rw-r--r--script/core/semantic-tokens.lua59
-rw-r--r--script/core/signature.lua33
-rw-r--r--script/core/type-definition.lua8
-rw-r--r--script/doctor.lua2
-rw-r--r--script/file-uri.lua28
-rw-r--r--script/files.lua47
-rw-r--r--script/filewatch.lua11
-rw-r--r--script/fs-utility.lua34
-rw-r--r--script/glob/gitignore.lua9
-rw-r--r--script/jsonrpc.lua1
-rw-r--r--script/lclient.lua3
-rw-r--r--script/library.lua102
-rw-r--r--script/linked-table.lua34
-rw-r--r--script/meta/bee/filesystem.lua91
-rw-r--r--script/parser/ast.lua1997
-rw-r--r--script/parser/calcline.lua94
-rw-r--r--script/parser/compile.lua4299
-rw-r--r--script/parser/grammar.lua573
-rw-r--r--script/parser/guide.lua247
-rw-r--r--script/parser/init.lua5
-rw-r--r--script/parser/luadoc.lua543
-rw-r--r--script/parser/newparser.lua3855
-rw-r--r--script/parser/parse.lua63
-rw-r--r--script/parser/split.lua9
-rw-r--r--script/parser/tokens.lua12
-rw-r--r--script/plugin.lua32
-rw-r--r--script/progress.lua4
-rw-r--r--script/proto/converter.lua2
-rw-r--r--script/proto/define.lua154
-rw-r--r--script/proto/diagnostic.lua267
-rw-r--r--script/provider/build-meta.lua155
-rw-r--r--script/provider/diagnostic.lua221
-rw-r--r--script/provider/formatting.lua39
-rw-r--r--script/provider/markdown.lua24
-rw-r--r--script/provider/provider.lua307
-rw-r--r--script/provider/spell.lua53
-rw-r--r--script/pub/pub.lua2
-rw-r--r--script/service/service.lua2
-rw-r--r--script/service/telemetry.lua1
-rw-r--r--script/utility.lua77
-rw-r--r--script/vm/compiler.lua1470
-rw-r--r--script/vm/def.lua91
-rw-r--r--script/vm/doc.lua59
-rw-r--r--script/vm/field.lua2
-rw-r--r--script/vm/function.lua245
-rw-r--r--script/vm/generic.lua19
-rw-r--r--script/vm/global.lua67
-rw-r--r--script/vm/infer.lua211
-rw-r--r--script/vm/init.lua4
-rw-r--r--script/vm/local-id.lua57
-rw-r--r--script/vm/node.lua223
-rw-r--r--script/vm/operator.lua368
-rw-r--r--script/vm/ref.lua94
-rw-r--r--script/vm/runner.lua677
-rw-r--r--script/vm/sign.lua67
-rw-r--r--script/vm/type.lua372
-rw-r--r--script/vm/value.lua37
-rw-r--r--script/vm/vm.lua14
-rw-r--r--script/workspace/require-path.lua320
-rw-r--r--script/workspace/scope.lua29
-rw-r--r--script/workspace/workspace.lua92
139 files changed, 11313 insertions, 10249 deletions
diff --git a/script/SDBMHash.lua b/script/SDBMHash.lua
new file mode 100644
index 00000000..48728aec
--- /dev/null
+++ b/script/SDBMHash.lua
@@ -0,0 +1,60 @@
+local byte = string.byte
+local max = 0x7fffffff
+
+---@class SDBMHash
+local mt = {}
+mt.__index = mt
+
+mt.cache = nil
+
+---@param str string
+---@return integer
+function mt:rawHash(str)
+ local id = 0
+ for i = 1, #str do
+ local b = byte(str, i, i)
+ id = id * 65599 + b
+ end
+ return id & max
+end
+
+---@param str string
+---@return integer
+function mt:hash(str)
+ local id = self:rawHash(str)
+ local other = self.cache[id]
+ if other == nil or str == other then
+ self.cache[id] = str
+ self.cache[str] = id
+ return id
+ else
+ log.warn(('哈希碰撞:[%s] -> [%s]: [%d]'):format(str, other, id))
+ for i = 1, max do
+ local newId = (id + i) % max
+ if not self.cache[newId] then
+ self.cache[newId] = str
+ self.cache[str] = newId
+ return newId
+ end
+ end
+ error(('哈希碰撞解决失败:[%s] -> [%s]: [%d]'):format(str, other, id))
+ end
+end
+
+function mt:setCache(t)
+ self.cache = t
+end
+
+function mt:getCache()
+ return self.cache
+end
+
+mt.__call = mt.hash
+
+---@return SDBMHash
+return function ()
+ local self = setmetatable({
+ cache = {}
+ }, mt)
+ return self
+end
diff --git a/script/await.lua b/script/await.lua
index 4fb81cd8..fa2aea13 100644
--- a/script/await.lua
+++ b/script/await.lua
@@ -132,9 +132,6 @@ end
---@param callback function
---@async
function m.wait(callback, ...)
- if not coroutine.isyieldable() then
- return
- end
local co = coroutine.running()
local resumed
callback(function (...)
diff --git a/script/cli/check.lua b/script/cli/check.lua
index dd2e7737..4df94c59 100644
--- a/script/cli/check.lua
+++ b/script/cli/check.lua
@@ -50,14 +50,14 @@ lclient():start(function (client)
ws.awaitReady(rootUri)
- local disables = config.get(rootUri, 'Lua.diagnostics.disable')
+ local disables = util.arrayToHash(config.get(rootUri, 'Lua.diagnostics.disable'))
for name, serverity in pairs(define.DiagnosticDefaultSeverity) do
serverity = config.get(rootUri, 'Lua.diagnostics.severity')[name] or 'Warning'
if define.DiagnosticSeverity[serverity] > checkLevel then
disables[name] = true
end
end
- config.set(nil, 'Lua.diagnostics.disable', disables)
+ config.set(nil, 'Lua.diagnostics.disable', util.getTableKeys(disables, true))
local uris = files.getAllUris(rootUri)
local max = #uris
diff --git a/script/client.lua b/script/client.lua
index 2ce29803..7432e60b 100644
--- a/script/client.lua
+++ b/script/client.lua
@@ -12,6 +12,7 @@ local scope = require 'workspace.scope'
local inspect = require 'inspect'
local m = {}
+m._eventList = {}
function m.client(newClient)
if newClient then
@@ -112,7 +113,7 @@ end
---@param type message.type
---@param message string
---@param titles string[]
----@param callback fun(action: string, index: integer)
+---@param callback fun(action?: string, index?: integer)
function m.requestMessage(type, message, titles, callback)
proto.notify('window/logMessage', {
type = define.MessageType[type] or 3,
@@ -239,6 +240,9 @@ local function tryModifySpecifiedConfig(uri, finalChanges)
return false
end
local path = workspace.getAbsolutePath(uri, CONFIGPATH)
+ if not path then
+ return false
+ end
util.saveFile(path, json.beautify(scp:get('lastLocalConfig'), { indent = ' ' }))
return true
end
@@ -252,7 +256,10 @@ local function tryModifyRC(uri, finalChanges, create)
if not path then
return false
end
- path = fs.exists(path) and path or workspace.getAbsolutePath(uri, '.luarc.json')
+ path = fs.exists(fs.path(path)) and path or workspace.getAbsolutePath(uri, '.luarc.json')
+ if not path then
+ return false
+ end
local buf = util.loadFile(path)
if not buf and not create then
return false
@@ -389,8 +396,22 @@ function m.editText(uri, edits)
})
end
+---@param callback async fun()
+function m.event(callback)
+ m._eventList[#m._eventList+1] = callback
+end
+
+function m._callEvent(ev)
+ for _, callback in ipairs(m._eventList) do
+ await.call(function ()
+ callback(ev)
+ end)
+ end
+end
+
function m.setReady()
m._ready = true
+ m._callEvent('ready')
end
function m.isReady()
@@ -415,6 +436,7 @@ function m.init(t)
lang(LOCALE or t.locale)
converter.setOffsetEncoding(m.getOffsetEncoding())
hookPrint()
+ m._callEvent('init')
end
return m
diff --git a/script/config/config.lua b/script/config/config.lua
index bde214b0..70e83fc8 100644
--- a/script/config/config.lua
+++ b/script/config/config.lua
@@ -1,224 +1,10 @@
-local util = require 'utility'
-local define = require 'proto.define'
-local timer = require 'timer'
-local scope = require 'workspace.scope'
+local util = require 'utility'
+local timer = require 'timer'
+local scope = require 'workspace.scope'
+local template = require 'config.template'
---@alias config.source '"client"'|'"path"'|'"local"'
----@class config.unit
----@field caller function
-local mt = {}
-mt.__index = mt
-
-function mt:__call(...)
- self:caller(...)
- return self
-end
-
-function mt:__shr(default)
- self.default = default
- return self
-end
-
-local units = {}
-
-local function register(name, default, checker, loader, caller)
- units[name] = {
- default = default,
- checker = checker,
- loader = loader,
- caller = caller,
- }
-end
-
-local Type = setmetatable({}, { __index = function (_, name)
- local unit = {}
- for k, v in pairs(units[name]) do
- unit[k] = v
- end
- return setmetatable(unit, mt)
-end })
-
-register('Boolean', false, function (self, v)
- return type(v) == 'boolean'
-end, function (self, v)
- return v
-end)
-
-register('Integer', 0, function (self, v)
- return type(v) == 'number'
-end, function (self, v)
- return math.floor(v)
-end)
-
-register('String', '', function (self, v)
- return type(v) == 'string'
-end, function (self, v)
- return tostring(v)
-end)
-
-register('Nil', nil, function (self, v)
- return type(v) == 'nil'
-end, function (self, v)
- return nil
-end)
-
-register('Array', {}, function (self, value)
- return type(value) == 'table'
-end, function (self, value)
- local t = {}
- for _, v in ipairs(value) do
- if self.sub:checker(v) then
- t[#t+1] = self.sub:loader(v)
- end
- end
- return t
-end, function (self, sub)
- self.sub = sub
-end)
-
-register('Hash', {}, function (self, value)
- if type(value) == 'table' then
- if #value == 0 then
- for k, v in pairs(value) do
- if not self.subkey:checker(k)
- or not self.subvalue:checker(v) then
- return false
- end
- end
- else
- if not self.subvalue:checker(true) then
- return false
- end
- for _, v in ipairs(value) do
- if not self.subkey:checker(v) then
- return false
- end
- end
- end
- return true
- end
- if type(value) == 'string' then
- return self.subkey:checker('')
- and self.subvalue:checker(true)
- end
-end, function (self, value)
- if type(value) == 'table' then
- local t = {}
- if #value == 0 then
- for k, v in pairs(value) do
- t[k] = v
- end
- else
- for _, k in pairs(value) do
- t[k] = true
- end
- end
- return t
- end
- if type(value) == 'string' then
- local t = {}
- for s in value:gmatch('[^' .. self.sep .. ']+') do
- t[s] = true
- end
- return t
- end
-end, function (self, subkey, subvalue, sep)
- self.subkey = subkey
- self.subvalue = subvalue
- self.sep = sep
-end)
-
-register('Or', nil, function (self, value)
- for _, sub in ipairs(self.subs) do
- if sub:checker(value) then
- return true
- end
- end
- return false
-end, function (self, value)
- for _, sub in ipairs(self.subs) do
- if sub:checker(value) then
- return sub:loader(value)
- end
- end
-end, function (self, ...)
- self.subs = { ... }
-end)
-
-local Template = {
- ['Lua.runtime.version'] = Type.String >> 'Lua 5.4',
- ['Lua.runtime.path'] = Type.Array(Type.String) >> {
- "?.lua",
- "?/init.lua",
- },
- ['Lua.runtime.pathStrict'] = Type.Boolean >> false,
- ['Lua.runtime.special'] = Type.Hash(Type.String, Type.String),
- ['Lua.runtime.meta'] = Type.String >> '${version} ${language} ${encoding}',
- ['Lua.runtime.unicodeName'] = Type.Boolean,
- ['Lua.runtime.nonstandardSymbol'] = Type.Hash(Type.String, Type.Boolean, ';'),
- ['Lua.runtime.plugin'] = Type.String,
- ['Lua.runtime.fileEncoding'] = Type.String >> 'utf8',
- ['Lua.runtime.builtin'] = Type.Hash(Type.String, Type.String),
- ['Lua.diagnostics.enable'] = Type.Boolean >> true,
- ['Lua.diagnostics.globals'] = Type.Hash(Type.String, Type.Boolean, ';'),
- ['Lua.diagnostics.disable'] = Type.Hash(Type.String, Type.Boolean, ';'),
- ['Lua.diagnostics.severity'] = Type.Hash(Type.String, Type.String)
- >> util.deepCopy(define.DiagnosticDefaultSeverity),
- ['Lua.diagnostics.neededFileStatus'] = Type.Hash(Type.String, Type.String)
- >> util.deepCopy(define.DiagnosticDefaultNeededFileStatus),
- ['Lua.diagnostics.workspaceDelay'] = Type.Integer >> 5,
- ['Lua.diagnostics.workspaceRate'] = Type.Integer >> 100,
- ['Lua.diagnostics.libraryFiles'] = Type.String >> 'Opened',
- ['Lua.diagnostics.ignoredFiles'] = Type.String >> 'Opened',
- ['Lua.workspace.ignoreDir'] = Type.Array(Type.String),
- ['Lua.workspace.ignoreSubmodules'] = Type.Boolean >> true,
- ['Lua.workspace.useGitIgnore'] = Type.Boolean >> true,
- ['Lua.workspace.maxPreload'] = Type.Integer >> 3000,
- ['Lua.workspace.preloadFileSize'] = Type.Integer >> 500,
- ['Lua.workspace.library'] = Type.Hash(Type.String, Type.Boolean, ';'),
- ['Lua.workspace.checkThirdParty'] = Type.Boolean >> true,
- ['Lua.workspace.userThirdParty'] = Type.Array(Type.String),
- ['Lua.completion.enable'] = Type.Boolean >> true,
- ['Lua.completion.callSnippet'] = Type.String >> 'Disable',
- ['Lua.completion.keywordSnippet'] = Type.String >> 'Replace',
- ['Lua.completion.displayContext'] = Type.Integer >> 0,
- ['Lua.completion.workspaceWord'] = Type.Boolean >> true,
- ['Lua.completion.showWord'] = Type.String >> 'Fallback',
- ['Lua.completion.autoRequire'] = Type.Boolean >> true,
- ['Lua.completion.showParams'] = Type.Boolean >> true,
- ['Lua.completion.requireSeparator'] = Type.String >> '.',
- ['Lua.completion.postfix'] = Type.String >> '@',
- ['Lua.signatureHelp.enable'] = Type.Boolean >> true,
- ['Lua.hover.enable'] = Type.Boolean >> true,
- ['Lua.hover.viewString'] = Type.Boolean >> true,
- ['Lua.hover.viewStringMax'] = Type.Integer >> 1000,
- ['Lua.hover.viewNumber'] = Type.Boolean >> true,
- ['Lua.hover.previewFields'] = Type.Integer >> 20,
- ['Lua.hover.enumsLimit'] = Type.Integer >> 5,
- ['Lua.hover.expandAlias'] = Type.Boolean >> true,
- ['Lua.semantic.enable'] = Type.Boolean >> true,
- ['Lua.semantic.variable'] = Type.Boolean >> true,
- ['Lua.semantic.annotation'] = Type.Boolean >> true,
- ['Lua.semantic.keyword'] = Type.Boolean >> false,
- ['Lua.hint.enable'] = Type.Boolean >> false,
- ['Lua.hint.paramType'] = Type.Boolean >> true,
- ['Lua.hint.setType'] = Type.Boolean >> false,
- ['Lua.hint.paramName'] = Type.String >> 'All',
- ['Lua.hint.await'] = Type.Boolean >> true,
- ['Lua.hint.arrayIndex'] = Type.Boolean >> 'Auto',
- ['Lua.window.statusBar'] = Type.Boolean >> true,
- ['Lua.window.progressBar'] = Type.Boolean >> true,
- ['Lua.format.enable'] = Type.Boolean >> true,
- ['Lua.format.defaultConfig'] = Type.Hash(Type.String, Type.String)
- >> {},
- ['Lua.telemetry.enable'] = Type.Or(Type.Boolean >> false, Type.Nil) >> nil,
- ['files.associations'] = Type.Hash(Type.String, Type.String),
- ['files.exclude'] = Type.Hash(Type.String, Type.Boolean),
- ['editor.semanticHighlighting.enabled'] = Type.Or(Type.Boolean, Type.String),
- ['editor.acceptSuggestionOnEnter'] = Type.String >> 'on',
-}
-
---@class config.api
local m = {}
m.watchList = {}
@@ -241,7 +27,7 @@ local function update(scp, key, nowValue, rawValue)
raw[key] = rawValue
end
----@param uri uri
+---@param uri? uri
---@param key? string
---@return scope
local function getScope(uri, key)
@@ -252,7 +38,7 @@ local function getScope(uri, key)
end
end
if uri then
- ---@type scope
+ ---@type scope?
local scp = scope.getFolder(uri) or scope.getLinkedScope(uri)
if scp then
if not key
@@ -268,7 +54,7 @@ end
---@param key string
---@param value any
function m.setByScope(scp, key, value)
- local unit = Template[key]
+ local unit = template[key]
if not unit then
return false
end
@@ -284,10 +70,12 @@ function m.setByScope(scp, key, value)
return true
end
----@param uri uri
+---@param uri? uri
---@param key string
---@param value any
function m.set(uri, key, value)
+ local unit = template[key]
+ assert(unit, 'unknown key: ' .. key)
local scp = getScope(uri)
local oldValue = m.get(uri, key)
m.setByScope(scp, key, value)
@@ -300,14 +88,10 @@ function m.set(uri, key, value)
end
function m.add(uri, key, value)
- local unit = Template[key]
- if not unit then
- return false
- end
+ local unit = template[key]
+ assert(unit, 'unknown key: ' .. key)
local list = m.getRaw(uri, key)
- if type(list) ~= 'table' then
- return false
- end
+ assert(type(list) == 'table', 'not a list: ' .. key)
local copyed = {}
for i, v in ipairs(list) do
if util.equal(v, value) then
@@ -326,15 +110,32 @@ function m.add(uri, key, value)
return false
end
-function m.prop(uri, key, prop, value)
- local unit = Template[key]
- if not unit then
- return false
+function m.remove(uri, key, value)
+ local unit = template[key]
+ assert(unit, 'unknown key: ' .. key)
+ local list = m.getRaw(uri, key)
+ assert(type(list) == 'table', 'not a list: ' .. key)
+ local copyed = {}
+ for i, v in ipairs(list) do
+ if not util.equal(v, value) then
+ copyed[i] = v
+ end
end
- local map = m.getRaw(uri, key)
- if type(map) ~= 'table' then
- return false
+ local oldValue = m.get(uri, key)
+ m.set(uri, key, copyed)
+ local newValue = m.get(uri, key)
+ if not util.equal(oldValue, newValue) then
+ m.event(uri, key, newValue, oldValue)
+ return true
end
+ return false
+end
+
+function m.prop(uri, key, prop, value)
+ local unit = template[key]
+ assert(unit, 'unknown key: ' .. key)
+ local map = m.getRaw(uri, key)
+ assert(type(map) == 'table', 'not a map: ' .. key)
if util.equal(map[prop], value) then
return false
end
@@ -353,14 +154,14 @@ function m.prop(uri, key, prop, value)
return false
end
----@param uri uri
+---@param uri? uri
---@param key string
---@return any
function m.get(uri, key)
local scp = getScope(uri, key)
local value = m.getNowTable(scp)[key]
if value == nil then
- value = Template[key].default
+ value = template[key].default
end
if value == m.NULL then
value = nil
@@ -375,7 +176,7 @@ function m.getRaw(uri, key)
local scp = getScope(uri, key)
local value = m.getRawTable(scp)[key]
if value == nil then
- value = Template[key].default
+ value = template[key].default
end
if value == m.NULL then
value = nil
@@ -412,9 +213,9 @@ function m.update(scp, ...)
if m.nullSymbols[value] then
value = m.NULL
end
- if Template[fullKey] then
+ if template[fullKey] then
m.setByScope(scp, fullKey, value)
- elseif Template['Lua.' .. fullKey] then
+ elseif template['Lua.' .. fullKey] then
m.setByScope(scp, 'Lua.' .. fullKey, value)
elseif type(value) == 'table' then
expand(value, fullKey)
@@ -424,7 +225,7 @@ function m.update(scp, ...)
local news = table.pack(...)
for i = 1, news.n do
- if news[i] then
+ if type(news[i]) == 'table' then
expand(news[i])
end
end
diff --git a/script/config/loader.lua b/script/config/loader.lua
index 30711dde..5cc7139f 100644
--- a/script/config/loader.lua
+++ b/script/config/loader.lua
@@ -17,6 +17,7 @@ end
---@class config.loader
local m = {}
+---@return table?
function m.loadRCConfig(uri, filename)
local scp = scope.getScope(uri)
local path = workspace.getAbsolutePath(uri, filename)
@@ -34,11 +35,16 @@ function m.loadRCConfig(uri, filename)
errorMessage(lang.script('CONFIG_LOAD_ERROR', res))
return scp:get('lastRCConfig')
end
+ ---@cast res table
scp:set('lastRCConfig', res)
return res
end
+---@return table?
function m.loadLocalConfig(uri, filename)
+ if not filename then
+ return nil
+ end
local scp = scope.getScope(uri)
local path = workspace.getAbsolutePath(uri, filename)
if not path then
@@ -60,6 +66,7 @@ function m.loadLocalConfig(uri, filename)
errorMessage(lang.script('CONFIG_LOAD_ERROR', res))
return scp:get('lastLocalConfig')
end
+ ---@cast res table
scp:set('lastLocalConfig', res)
scp:set('lastLocalType', 'json')
return res
@@ -79,7 +86,7 @@ end
---@async
---@param uri? uri
----@return table
+---@return table?
function m.loadClientConfig(uri)
local configs = proto.awaitRequest('workspace/configuration', {
items = {
diff --git a/script/config/template.lua b/script/config/template.lua
new file mode 100644
index 00000000..60f3dbca
--- /dev/null
+++ b/script/config/template.lua
@@ -0,0 +1,388 @@
+local util = require 'utility'
+local define = require 'proto.define'
+local diag = require 'proto.diagnostic'
+
+---@class config.unit
+---@field caller function
+---@field _checker fun(self: config.unit, value: any): boolean
+---@field name string
+---@field [string] config.unit
+---@operator shl: config.unit
+---@operator shr: config.unit
+---@operator call: config.unit
+local mt = {}
+mt.__index = mt
+
+function mt:__call(...)
+ self:caller(...)
+ return self
+end
+
+function mt:__shr(default)
+ self.default = default
+ self.hasDefault = true
+ return self
+end
+
+function mt:__shl(enums)
+ self.enums = enums
+ return self
+end
+
+function mt:checker(v)
+ if self.enums then
+ local ok
+ for _, enum in ipairs(self.enums) do
+ if util.equal(enum, v) then
+ ok = true
+ break
+ end
+ end
+ if not ok then
+ return false
+ end
+ end
+ return self:_checker(v)
+end
+
+local units = {}
+
+local function register(name, default, checker, loader, caller)
+ units[name] = {
+ name = name,
+ default = default,
+ _checker = checker,
+ loader = loader,
+ caller = caller,
+ }
+end
+
+---@type config.unit
+local Type = setmetatable({}, { __index = function (_, name)
+ local unit = {}
+ for k, v in pairs(units[name]) do
+ unit[k] = v
+ end
+ return setmetatable(unit, mt)
+end })
+
+register('Boolean', false, function (self, v)
+ return type(v) == 'boolean'
+end, function (self, v)
+ return v
+end)
+
+register('Integer', 0, function (self, v)
+ return type(v) == 'number'
+end, function (self, v)
+ return math.floor(v)
+end)
+
+register('String', '', function (self, v)
+ return type(v) == 'string'
+end, function (self, v)
+ return tostring(v)
+end)
+
+register('Nil', nil, function (self, v)
+ return type(v) == 'nil'
+end, function (self, v)
+ return nil
+end)
+
+register('Array', {}, function (self, value)
+ return type(value) == 'table'
+end, function (self, value)
+ local t = {}
+ if #value == 0 then
+ for k in pairs(value) do
+ if self.sub:checker(k) then
+ t[#t+1] = self.sub:loader(k)
+ end
+ end
+ else
+ for _, v in ipairs(value) do
+ if self.sub:checker(v) then
+ t[#t+1] = self.sub:loader(v)
+ end
+ end
+ end
+ return t
+end, function (self, sub)
+ self.sub = sub
+end)
+
+register('Hash', {}, function (self, value)
+ if type(value) == 'table' then
+ if #value == 0 then
+ for k, v in pairs(value) do
+ if not self.subkey:checker(k)
+ or not self.subvalue:checker(v) then
+ return false
+ end
+ end
+ else
+ if not self.subvalue:checker(true) then
+ return false
+ end
+ for _, v in ipairs(value) do
+ if not self.subkey:checker(v) then
+ return false
+ end
+ end
+ end
+ return true
+ end
+ if type(value) == 'string' then
+ return self.subkey:checker('')
+ and self.subvalue:checker(true)
+ end
+end, function (self, value)
+ if type(value) == 'table' then
+ local t = {}
+ if #value == 0 then
+ for k, v in pairs(value) do
+ t[k] = v
+ end
+ else
+ for _, k in pairs(value) do
+ t[k] = true
+ end
+ end
+ return t
+ end
+ if type(value) == 'string' then
+ local t = {}
+ for s in value:gmatch('[^' .. self.sep .. ']+') do
+ t[s] = true
+ end
+ return t
+ end
+end, function (self, subkey, subvalue, sep)
+ self.subkey = subkey
+ self.subvalue = subvalue
+ self.sep = sep
+end)
+
+register('Or', nil, function (self, value)
+ for _, sub in ipairs(self.subs) do
+ if sub:checker(value) then
+ return true
+ end
+ end
+ return false
+end, function (self, value)
+ for _, sub in ipairs(self.subs) do
+ if sub:checker(value) then
+ return sub:loader(value)
+ end
+ end
+end, function (self, ...)
+ self.subs = { ... }
+end)
+
+local template = {
+ ['Lua.runtime.version'] = Type.String >> 'Lua 5.4' << {
+ 'Lua 5.1',
+ 'Lua 5.2',
+ 'Lua 5.3',
+ 'Lua 5.4',
+ 'LuaJIT',
+ },
+ ['Lua.runtime.path'] = Type.Array(Type.String) >> {
+ "?.lua",
+ "?/init.lua",
+ },
+ ['Lua.runtime.pathStrict'] = Type.Boolean >> false,
+ ['Lua.runtime.special'] = Type.Hash(
+ Type.String,
+ Type.String >> 'require' << {
+ '_G',
+ 'rawset',
+ 'rawget',
+ 'setmetatable',
+ 'require',
+ 'dofile',
+ 'loadfile',
+ 'pcall',
+ 'xpcall',
+ 'assert',
+ 'error',
+ 'type',
+ }
+ ),
+ ['Lua.runtime.meta'] = Type.String >> '${version} ${language} ${encoding}',
+ ['Lua.runtime.unicodeName'] = Type.Boolean,
+ ['Lua.runtime.nonstandardSymbol'] = Type.Array(Type.String << {
+ '//', '/**/',
+ '`',
+ '+=', '-=', '*=', '/=', '%=', '^=', '//=',
+ '|=', '&=', '<<=', '>>=',
+ '||', '&&', '!', '!=',
+ 'continue',
+ }),
+ ['Lua.runtime.plugin'] = Type.String,
+ ['Lua.runtime.pluginArgs'] = Type.Array(Type.String),
+ ['Lua.runtime.fileEncoding'] = Type.String >> 'utf8' << {
+ 'utf8',
+ 'ansi',
+ 'utf16le',
+ 'utf16be',
+ },
+ ['Lua.runtime.builtin'] = Type.Hash(
+ Type.String << util.getTableKeys(define.BuiltIn, true),
+ Type.String >> 'default' << {
+ 'default',
+ 'enable',
+ 'disable',
+ }
+ )
+ >> util.deepCopy(define.BuiltIn),
+ ['Lua.diagnostics.enable'] = Type.Boolean >> true,
+ ['Lua.diagnostics.globals'] = Type.Array(Type.String),
+ ['Lua.diagnostics.disable'] = Type.Array(Type.String << util.getTableKeys(diag.getDiagAndErrNameMap(), true)),
+ ['Lua.diagnostics.severity'] = Type.Hash(
+ Type.String << util.getTableKeys(define.DiagnosticDefaultNeededFileStatus, true),
+ Type.String << {
+ 'Error',
+ 'Warning',
+ 'Information',
+ 'Hint',
+ 'Error!',
+ 'Warning!',
+ 'Information!',
+ 'Hint!',
+ }
+ )
+ >> util.deepCopy(define.DiagnosticDefaultSeverity),
+ ['Lua.diagnostics.neededFileStatus'] = Type.Hash(
+ Type.String << util.getTableKeys(define.DiagnosticDefaultNeededFileStatus, true),
+ Type.String << {
+ 'Any',
+ 'Opened',
+ 'None',
+ 'Any!',
+ 'Opened!',
+ 'None!',
+ }
+ )
+ >> util.deepCopy(define.DiagnosticDefaultNeededFileStatus),
+ ['Lua.diagnostics.groupSeverity'] = Type.Hash(
+ Type.String << util.getTableKeys(define.DiagnosticDefaultGroupSeverity, true),
+ Type.String << {
+ 'Error',
+ 'Warning',
+ 'Information',
+ 'Hint',
+ 'Fallback',
+ }
+ )
+ >> util.deepCopy(define.DiagnosticDefaultGroupSeverity),
+ ['Lua.diagnostics.groupFileStatus'] = Type.Hash(
+ Type.String << util.getTableKeys(define.DiagnosticDefaultGroupFileStatus, true),
+ Type.String << {
+ 'Any',
+ 'Opened',
+ 'None',
+ 'Fallback',
+ }
+ )
+ >> util.deepCopy(define.DiagnosticDefaultGroupFileStatus),
+ ['Lua.diagnostics.disableScheme'] = Type.Array(Type.String) >> { 'git' },
+ ['Lua.diagnostics.workspaceDelay'] = Type.Integer >> 3000,
+ ['Lua.diagnostics.workspaceRate'] = Type.Integer >> 100,
+ ['Lua.diagnostics.libraryFiles'] = Type.String >> 'Opened' << {
+ 'Enable',
+ 'Opened',
+ 'Disable',
+ },
+ ['Lua.diagnostics.ignoredFiles'] = Type.String >> 'Opened' << {
+ 'Enable',
+ 'Opened',
+ 'Disable',
+ },
+ ['Lua.diagnostics.unusedLocalExclude'] = Type.Array(Type.String),
+ ['Lua.workspace.ignoreDir'] = Type.Array(Type.String) >> {
+ '.vscode',
+ },
+ ['Lua.workspace.ignoreSubmodules'] = Type.Boolean >> true,
+ ['Lua.workspace.useGitIgnore'] = Type.Boolean >> true,
+ ['Lua.workspace.maxPreload'] = Type.Integer >> 5000,
+ ['Lua.workspace.preloadFileSize'] = Type.Integer >> 500,
+ ['Lua.workspace.library'] = Type.Array(Type.String),
+ ['Lua.workspace.checkThirdParty'] = Type.Boolean >> true,
+ ['Lua.workspace.userThirdParty'] = Type.Array(Type.String),
+ ['Lua.workspace.supportScheme'] = Type.Array(Type.String) >> { 'file', 'untitled', 'git' },
+ ['Lua.completion.enable'] = Type.Boolean >> true,
+ ['Lua.completion.callSnippet'] = Type.String >> 'Disable' << {
+ 'Disable',
+ 'Both',
+ 'Replace',
+ },
+ ['Lua.completion.keywordSnippet'] = Type.String >> 'Replace' << {
+ 'Disable',
+ 'Both',
+ 'Replace',
+ },
+ ['Lua.completion.displayContext'] = Type.Integer >> 0,
+ ['Lua.completion.workspaceWord'] = Type.Boolean >> true,
+ ['Lua.completion.showWord'] = Type.String >> 'Fallback' << {
+ 'Enable',
+ 'Fallback',
+ 'Disable',
+ },
+ ['Lua.completion.autoRequire'] = Type.Boolean >> true,
+ ['Lua.completion.showParams'] = Type.Boolean >> true,
+ ['Lua.completion.requireSeparator'] = Type.String >> '.',
+ ['Lua.completion.postfix'] = Type.String >> '@',
+ ['Lua.signatureHelp.enable'] = Type.Boolean >> true,
+ ['Lua.hover.enable'] = Type.Boolean >> true,
+ ['Lua.hover.viewString'] = Type.Boolean >> true,
+ ['Lua.hover.viewStringMax'] = Type.Integer >> 1000,
+ ['Lua.hover.viewNumber'] = Type.Boolean >> true,
+ ['Lua.hover.previewFields'] = Type.Integer >> 50,
+ ['Lua.hover.enumsLimit'] = Type.Integer >> 5,
+ ['Lua.hover.expandAlias'] = Type.Boolean >> true,
+ ['Lua.semantic.enable'] = Type.Boolean >> true,
+ ['Lua.semantic.variable'] = Type.Boolean >> true,
+ ['Lua.semantic.annotation'] = Type.Boolean >> true,
+ ['Lua.semantic.keyword'] = Type.Boolean >> false,
+ ['Lua.hint.enable'] = Type.Boolean >> false,
+ ['Lua.hint.paramType'] = Type.Boolean >> true,
+ ['Lua.hint.setType'] = Type.Boolean >> false,
+ ['Lua.hint.paramName'] = Type.String >> 'All' << {
+ 'All',
+ 'Literal',
+ 'Disable',
+ },
+ ['Lua.hint.await'] = Type.Boolean >> true,
+ ['Lua.hint.arrayIndex'] = Type.String >> 'Auto' << {
+ 'Enable',
+ 'Auto',
+ 'Disable',
+ },
+ ['Lua.hint.semicolon'] = Type.String >> 'SameLine' << {
+ 'All',
+ 'SameLine',
+ 'Disable',
+ },
+ ['Lua.window.statusBar'] = Type.Boolean >> true,
+ ['Lua.window.progressBar'] = Type.Boolean >> true,
+ ['Lua.format.enable'] = Type.Boolean >> true,
+ ['Lua.format.defaultConfig'] = Type.Hash(Type.String, Type.String)
+ >> {},
+ ['Lua.spell.dict'] = Type.Array(Type.String),
+ ['Lua.telemetry.enable'] = Type.Or(Type.Boolean >> false, Type.Nil) >> nil,
+ ['Lua.misc.parameters'] = Type.Array(Type.String),
+ ['Lua.type.castNumberToInteger'] = Type.Boolean >> true,
+ ['Lua.type.weakUnionCheck'] = Type.Boolean >> false,
+ ['Lua.type.weakNilCheck'] = Type.Boolean >> false,
+
+ -- VSCode
+ ['files.associations'] = Type.Hash(Type.String, Type.String),
+ ['files.exclude'] = Type.Hash(Type.String, Type.Boolean),
+ ['editor.semanticHighlighting.enabled'] = Type.Or(Type.Boolean, Type.String),
+ ['editor.acceptSuggestionOnEnter'] = Type.String >> 'on',
+}
+
+return template
diff --git a/script/core/code-action.lua b/script/core/code-action.lua
index 6bba0a82..4eb21ff8 100644
--- a/script/core/code-action.lua
+++ b/script/core/code-action.lua
@@ -5,11 +5,18 @@ local sp = require 'bee.subprocess'
local guide = require "parser.guide"
local converter = require 'proto.converter'
+---@param uri uri
+---@param row integer
+---@param mode string
+---@param code string
local function checkDisableByLuaDocExits(uri, row, mode, code)
if row < 0 then
return nil
end
local state = files.getState(uri)
+ if not state then
+ return nil
+ end
local lines = state.lines
if state.ast.docs and lines then
return guide.eachSourceBetween(
@@ -124,9 +131,12 @@ local function changeVersion(uri, version, results)
end
local function solveUndefinedGlobal(uri, diag, results)
- local ast = files.getState(uri)
- local start = converter.unpackRange(uri, diag.range)
- guide.eachSourceContain(ast.ast, start, function (source)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+ local start = converter.unpackRange(uri, diag.range)
+ guide.eachSourceContain(state.ast, start, function (source)
if source.type ~= 'getglobal' then
return
end
@@ -143,9 +153,12 @@ local function solveUndefinedGlobal(uri, diag, results)
end
local function solveLowercaseGlobal(uri, diag, results)
- local ast = files.getState(uri)
- local start = converter.unpackRange(uri, diag.range)
- guide.eachSourceContain(ast.ast, start, function (source)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+ local start = converter.unpackRange(uri, diag.range)
+ guide.eachSourceContain(state.ast, start, function (source)
if source.type ~= 'setglobal' then
return
end
@@ -156,8 +169,11 @@ local function solveLowercaseGlobal(uri, diag, results)
end
local function findSyntax(uri, diag)
- local ast = files.getState(uri)
- for _, err in ipairs(ast.errs) do
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+ for _, err in ipairs(state.errs) do
if err.type:lower():gsub('_', '-') == diag.code then
local range = converter.packRange(uri, err.start, err.finish)
if util.equal(range, diag.range) then
@@ -333,6 +349,8 @@ local function solveAwaitInSync(uri, diag, results)
end
local row = guide.rowColOf(parentFunction.start)
local pos = guide.positionOf(row, 0)
+ local offset = guide.positionToOffset(state, pos + 1)
+ local space = state.lua:match('[ \t]*', offset)
results[#results+1] = {
title = lang.script.ACTION_MARK_ASYNC,
kind = 'quickfix',
@@ -342,7 +360,7 @@ local function solveAwaitInSync(uri, diag, results)
{
start = pos,
finish = pos,
- newText = '---@async\n',
+ newText = space .. '---@async\n',
}
}
}
@@ -350,6 +368,51 @@ local function solveAwaitInSync(uri, diag, results)
}
end
+local function solveSpell(uri, diag, results)
+ local spell = require 'provider.spell'
+ local word = diag.data
+ if word == nil then
+ return
+ end
+
+ results[#results+1] = {
+ title = lang.script('ACTION_ADD_DICT', word),
+ kind = 'quickfix',
+ command = {
+ title = lang.script.COMMAND_ADD_DICT,
+ command = 'lua.setConfig',
+ arguments = {
+ {
+ key = 'Lua.spell.dict',
+ action = 'add',
+ value = word,
+ uri = uri,
+ }
+ }
+ }
+ }
+
+ local suggests = spell.getSpellSuggest(word)
+ for _, suggest in ipairs(suggests) do
+ results[#results+1] = {
+ title = suggest,
+ kind = 'quickfix',
+ edit = {
+ changes = {
+ [uri] = {
+ {
+ start = converter.unpackPosition(uri, diag.range.start),
+ finish = converter.unpackPosition(uri, diag.range["end"]),
+ newText = suggest
+ }
+ }
+ }
+ }
+ }
+ end
+
+end
+
local function solveDiagnostic(uri, diag, start, results)
if diag.source == lang.script.DIAG_SYNTAX_CHECK then
solveSyntax(uri, diag, results)
@@ -370,6 +433,8 @@ local function solveDiagnostic(uri, diag, start, results)
solveTrailingSpace(uri, diag, results)
elseif diag.code == 'await-in-sync' then
solveAwaitInSync(uri, diag, results)
+ elseif diag.code == 'spell-check' then
+ solveSpell(uri, diag, results)
end
disableDiagnostic(uri, diag.code, start, results)
end
@@ -386,7 +451,7 @@ end
local function checkSwapParams(results, uri, start, finish)
local state = files.getState(uri)
local text = files.getText(uri)
- if not state then
+ if not state or not text then
return
end
local args = {}
@@ -554,6 +619,9 @@ end
local function checkJsonToLua(results, uri, start, finish)
local text = files.getText(uri)
local state = files.getState(uri)
+ if not state or not text then
+ return
+ end
local startOffset = guide.positionToOffset(state, start)
local finishOffset = guide.positionToOffset(state, finish)
local jsonStart = text:match('()[%{%[]', startOffset + 1)
diff --git a/script/core/collector.lua b/script/core/collector.lua
deleted file mode 100644
index a2e3ca08..00000000
--- a/script/core/collector.lua
+++ /dev/null
@@ -1,188 +0,0 @@
-local scope = require 'workspace.scope'
-
----@class collector
----@field subscribed table<uri, table<string, any>>
----@field collect table<string, table<uri, any>>
-local mt = {}
-mt.__index = mt
-
---- 订阅一个名字
----@param uri uri
----@param name string
----@param value any
-function mt:subscribe(uri, name, value)
- uri = uri or '<fallback>'
- -- 订阅部分
- local uriSubscribed = self.subscribed[uri]
- if not uriSubscribed then
- uriSubscribed = {}
- self.subscribed[uri] = uriSubscribed
- end
- uriSubscribed[name] = true
- -- 收集部分
- local nameCollect = self.collect[name]
- if not nameCollect then
- nameCollect = {}
- self.collect[name] = nameCollect
- end
- if value == nil then
- value = true
- end
- nameCollect[uri] = value
-end
-
---- 丢弃掉某个 uri 中收集的所有信息
----@param uri uri
-function mt:dropUri(uri)
- uri = uri or '<fallback>'
- local uriSubscribed = self.subscribed[uri]
- if not uriSubscribed then
- return
- end
- self.subscribed[uri] = nil
- for name in pairs(uriSubscribed) do
- self.collect[name][uri] = nil
- if not next(self.collect[name]) then
- self.collect[name] = nil
- end
- end
-end
-
-function mt:dropAll()
- self.subscribed = {}
- self.collect = {}
-end
-
---- 是否包含某个名字的订阅
----@param uri uri
----@param name string
----@return boolean
-function mt:has(uri, name)
- if self:each(uri, name)() then
- return true
- else
- return false
- end
-end
-
-local DUMMY_FUNCTION = function () end
-
----@param scp scope
-local function eachOfFolder(nameCollect, scp)
- local curi, value
-
- local function getNext()
- curi, value = next(nameCollect, curi)
- if not curi then
- return nil, nil
- end
- if scp:isChildUri(curi)
- or scp:isLinkedUri(curi) then
- return value, curi
- end
- return getNext()
- end
-
- return getNext
-end
-
----@param scp scope
-local function eachOfLinked(nameCollect, scp)
- local curi, value
-
- local function getNext()
- curi, value = next(nameCollect, curi)
- if not curi then
- return nil, nil
- end
- if scp:isChildUri(curi)
- and scp:isLinkedUri(curi) then
- return value, curi
- end
-
- local cscp = scope.getFolder(curi)
- or scope.getLinkedScope(curi)
- or scope.fallback
-
- if cscp == scp
- or cscp:isChildUri(scp.uri)
- or cscp:isLinkedUri(scp.uri) then
- return value, curi
- end
-
- return getNext()
- end
-
- return getNext
-end
-
----@param scp scope
-local function eachOfFallback(nameCollect, scp)
- local curi, value
-
- local function getNext()
- curi, value = next(nameCollect, curi)
- if not curi then
- return nil, nil
- end
- if scp:isLinkedUri(curi) then
- return value, curi
- end
-
- local cscp = scope.getFolder(curi)
- or scope.getLinkedScope(curi)
- or scope.fallback
-
- if cscp == scp then
- return value, curi
- end
-
- return getNext()
- end
-
- return getNext
-end
-
---- 迭代某个名字的订阅
----@param uri uri
----@param name string
-function mt:each(uri, name)
- uri = uri or '<fallback>'
- local nameCollect = self.collect[name]
- if not nameCollect then
- return DUMMY_FUNCTION
- end
-
- local scp = scope.getFolder(uri)
-
- if scp then
- return eachOfFolder(nameCollect, scp)
- end
-
- scp = scope.getLinkedScope(uri)
-
- if scp then
- return eachOfLinked(nameCollect, scp)
- end
-
- return eachOfFallback(nameCollect, scope.fallback)
-end
-
-local collectors = {}
-
-local function new()
- return setmetatable({
- collect = {},
- subscribed = {},
- }, mt)
-end
-
----@return collector
-return function (name)
- if name then
- collectors[name] = collectors[name] or new()
- return collectors[name]
- else
- return new()
- end
-end
diff --git a/script/core/color.lua b/script/core/color.lua
new file mode 100644
index 00000000..2cbcce11
--- /dev/null
+++ b/script/core/color.lua
@@ -0,0 +1,79 @@
+local files = require "files"
+local guide = require "parser.guide"
+
+local colorPattern = string.rep('%x', 8)
+---@param source parser.object
+---@return boolean
+local function isColor(source)
+ ---@type string
+ local text = source[1]
+ if text:len() ~= 8 then
+ return false
+ end
+ return text:match(colorPattern)
+end
+
+
+---@param colorText string
+---@return Color
+local function textToColor(colorText)
+ return {
+ alpha = tonumber(colorText:sub(1, 2), 16) / 255,
+ red = tonumber(colorText:sub(3, 4), 16) / 255,
+ green = tonumber(colorText:sub(5, 6), 16) / 255,
+ blue = tonumber(colorText:sub(7, 8), 16) / 255,
+ }
+end
+
+
+---@param color Color
+---@return string
+local function colorToText(color)
+ return string.format('%02X%02X%02X%02X'
+ , math.floor(color.alpha * 255)
+ , math.floor(color.red * 255)
+ , math.floor(color.green * 255)
+ , math.floor(color.blue * 255)
+ )
+end
+
+---@class Color
+---@field red number
+---@field green number
+---@field blue number
+---@field alpha number
+
+---@class ColorValue
+---@field color Color
+---@field start integer
+---@field finish integer
+
+---@async
+local function colors(uri)
+ local state = files.getState(uri)
+ local text = files.getText(uri)
+ if not state or not text then
+ return nil
+ end
+ ---@type ColorValue[]
+ local colorValues = {}
+
+ guide.eachSource(state.ast, function (source) ---@async
+ if source.type == 'string' and isColor(source) then
+ ---@type string
+ local colorText = source[1]
+
+ colorValues[#colorValues+1] = {
+ start = source.start + 1,
+ finish = source.finish - 1,
+ color = textToColor(colorText)
+ }
+ end
+ end)
+ return colorValues
+end
+
+return {
+ colors = colors,
+ colorToText = colorToText
+}
diff --git a/script/core/command/autoRequire.lua b/script/core/command/autoRequire.lua
index c0deecfc..32911d92 100644
--- a/script/core/command/autoRequire.lua
+++ b/script/core/command/autoRequire.lua
@@ -21,6 +21,9 @@ end
local function findInsertRow(uri)
local text = files.getText(uri)
local state = files.getState(uri)
+ if not state or not text then
+ return
+ end
local lines = state.lines
local fmt = {
pair = false,
@@ -68,7 +71,7 @@ local function askAutoRequire(uri, visiblePaths)
local selects = {}
local nameMap = {}
for _, visible in ipairs(visiblePaths) do
- local expect = visible.expect
+ local expect = visible.name
local select = lang.script(expect)
if not nameMap[select] then
nameMap[select] = expect
@@ -143,7 +146,7 @@ return function (data)
return
end
table.sort(visiblePaths, function (a, b)
- return #a.expect < #b.expect
+ return #a.name < #b.name
end)
local result = askAutoRequire(uri, visiblePaths)
diff --git a/script/core/command/removeSpace.lua b/script/core/command/removeSpace.lua
index aa565f7f..992a0705 100644
--- a/script/core/command/removeSpace.lua
+++ b/script/core/command/removeSpace.lua
@@ -4,20 +4,12 @@ local proto = require 'proto'
local lang = require 'language'
local converter = require 'proto.converter'
-local function isInString(ast, offset)
- return guide.eachSourceContain(ast.ast, offset, function (source)
- if source.type == 'string' then
- return true
- end
- end) or false
-end
-
---@async
return function (data)
local uri = data.uri
local text = files.getText(uri)
local state = files.getState(uri)
- if not state then
+ if not state or not text then
return
end
@@ -32,7 +24,8 @@ return function (data)
goto NEXT_LINE
end
local lastPos = guide.offsetToPosition(state, lastOffset)
- if isInString(state.ast, lastPos) then
+ if guide.isInString(state.ast, lastPos)
+ or guide.isInComment(state.ast, lastPos) then
goto NEXT_LINE
end
local firstOffset = startOffset
diff --git a/script/core/command/solve.lua b/script/core/command/solve.lua
index 8065aa9d..98ceaa58 100644
--- a/script/core/command/solve.lua
+++ b/script/core/command/solve.lua
@@ -32,7 +32,7 @@ return function (data)
local uri = data.uri
local text = files.getText(uri)
local state = files.getState(uri)
- if not state then
+ if not state or not text then
return
end
diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua
index d4c20c60..8f28e450 100644
--- a/script/core/completion/completion.lua
+++ b/script/core/completion/completion.lua
@@ -18,6 +18,7 @@ local lookBackward = require 'core.look-backward'
local guide = require 'parser.guide'
local await = require 'await'
local postfix = require 'core.completion.postfix'
+local diag = require 'proto.diagnostic'
local diagnosticModes = {
'disable-next-line',
@@ -56,6 +57,7 @@ local function trim(str)
end
local function findNearestSource(state, position)
+ ---@type parser.object
local source
guide.eachSourceContain(state.ast, position, function (src)
source = src
@@ -66,6 +68,9 @@ end
local function findNearestTableField(state, position)
local uri = state.uri
local text = files.getText(uri)
+ if not text then
+ return nil
+ end
local offset = guide.positionToOffset(state, position)
local soffset = lookBackward.findAnyOffset(text, offset)
if not soffset then
@@ -155,36 +160,24 @@ local function buildFunctionSnip(source, value, oop)
if oop then
table.remove(args, 1)
end
- local len = #args
- local truncated = false
- if len > 0 and args[len]:match('^%s*%.%.%.:') then
- table.remove(args)
- truncated = true
- end
- for i = #args, 1, -1 do
- if args[i]:match('^%s*[^?]+%?:') then
- table.remove(args)
- truncated = true
- else
- break
- end
- end
local snipArgs = {}
for id, arg in ipairs(args) do
- local str = arg:gsub('^(%s*)(.+)', function (sp, word)
+ local str, count = arg:gsub('^(%s*)(%.%.%.)(.+)', function (sp, word)
return ('%s${%d:%s}'):format(sp, id, word)
end)
+ if count == 0 then
+ str = arg:gsub('^(%s*)([^:]+)(.+)', function (sp, word)
+ return ('%s${%d:%s}'):format(sp, id, word)
+ end)
+ end
table.insert(snipArgs, str)
end
- if truncated and #snipArgs == 0 then
- snipArgs = {'$1'}
- end
return ('%s(%s)'):format(name, table.concat(snipArgs, ', '))
end
local function buildDetail(source)
- local types = vm.getInfer(source):view()
+ local types = vm.getInfer(source):view(guide.getUri(source))
local literals = vm.getInfer(source):viewLiterals()
if literals then
return types .. ' = ' .. literals
@@ -204,6 +197,9 @@ local function getSnip(source)
local uri = guide.getUri(def)
local text = files.getText(uri)
local state = files.getState(uri)
+ if not state then
+ goto CONTINUE
+ end
local lines = state.lines
if not text then
goto CONTINUE
@@ -302,7 +298,7 @@ local function checkLocal(state, word, position, results)
if name:sub(1, 1) == '@' then
goto CONTINUE
end
- if vm.getInfer(source):hasFunction() then
+ if vm.getInfer(source):hasFunction(state.uri) then
local defs = vm.getDefs(source)
-- make sure `function` is before `doc.type.function`
local orders = {}
@@ -356,6 +352,7 @@ local function checkModule(state, word, position, results)
if not config.get(state.uri, 'Lua.completion.autoRequire') then
return
end
+ local globals = util.arrayToHash(config.get(state.uri, 'Lua.diagnostics.globals'))
local locals = guide.getVisibleLocals(state.ast, position)
for uri in files.eachFile(state.uri) do
if uri == guide.getUri(state.ast) then
@@ -366,7 +363,7 @@ local function checkModule(state, word, position, results)
local stemName = fileName:gsub('%..+', '')
if not locals[stemName]
and not vm.hasGlobalSets(state.uri, 'variable', stemName)
- and not config.get(state.uri, 'Lua.diagnostics.globals')[stemName]
+ and not globals[stemName]
and stemName:match '^[%a_][%w_]*$'
and matchKey(word, stemName) then
local targetState = files.getState(uri)
@@ -488,7 +485,7 @@ local function checkFieldFromFieldToIndex(state, name, src, parent, word, startP
end
local function checkFieldThen(state, name, src, word, startPos, position, parent, oop, results)
- local value = vm.getObjectValue(src) or src
+ local value = vm.getObjectFunctionValue(src) or src
local kind = define.CompletionItemKind.Field
if value.type == 'function'
or value.type == 'doc.type.function' then
@@ -512,7 +509,7 @@ local function checkFieldThen(state, name, src, word, startPos, position, parent
})
return
end
- if oop and not vm.getInfer(src):hasFunction() then
+ if oop and not vm.getInfer(src):hasFunction(state.uri) then
return
end
local literal = guide.getLiteral(value)
@@ -568,7 +565,8 @@ local function checkFieldOfRefs(refs, state, word, startPos, position, parent, o
end
local funcLabel
if config.get(state.uri, 'Lua.completion.showParams') then
- local value = vm.getObjectValue(src) or src
+ --- TODO determine if getlocal should be a function here too
+ local value = vm.getObjectFunctionValue(src) or src
if value.type == 'function'
or value.type == 'doc.type.function' then
funcLabel = name .. getParams(value, oop)
@@ -916,24 +914,24 @@ local function collectRequireNames(mode, myUri, literal, source, smark, position
goto CONTINUE
end
local path = furi.decode(uri)
- local infos = rpath.getVisiblePath(uri, path)
+ local infos = rpath.getVisiblePath(myUri, path)
local relative = workspace.getRelativePath(path)
for _, info in ipairs(infos) do
- if matchKey(literal, info.expect) then
- if not collect[info.expect] then
- collect[info.expect] = {
+ if matchKey(literal, info.name) then
+ if not collect[info.name] then
+ collect[info.name] = {
textEdit = {
start = smark and (source.start + #smark) or position,
finish = smark and (source.finish - #smark) or position,
- newText = smark and info.expect or util.viewString(info.expect),
+ newText = smark and info.name or util.viewString(info.name),
},
path = relative,
}
end
if vm.isMetaFile(uri) then
- collect[info.expect][#collect[info.expect]+1] = ('* [[meta]](%s)'):format(uri)
+ collect[info.name][#collect[info.name]+1] = ('* [[meta]](%s)'):format(uri)
else
- collect[info.expect][#collect[info.expect]+1] = ([=[* [%s](%s) %s]=]):format(
+ collect[info.name][#collect[info.name]+1] = ([=[* [%s](%s) %s]=]):format(
relative,
uri,
lang.script('HOVER_USE_LUA_PATH', info.searcher)
@@ -1098,11 +1096,11 @@ local function tryLabelInString(label, source)
if not source or source.type ~= 'string' then
return label
end
- local state = parser.parse(label, 'String')
+ local state = parser.compile(label, 'String')
if not state or not state.ast then
return label
end
- if not matchKey(source[1], state.ast[1]) then
+ if not matchKey(source[1], state.ast[1]--[[@as string]]) then
return nil
end
return util.viewString(state.ast[1], source[2])
@@ -1124,18 +1122,112 @@ local function cleanEnums(enums, source)
return enums
end
-local function checkTypingEnum(state, position, defs, str, results)
+---@param state parser.state
+---@param pos integer
+---@param doc vm.node.object
+---@param enums table[]
+---@return table[]?
+local function insertDocEnum(state, pos, doc, enums)
+ local tbl = doc.bindSource
+ if not tbl then
+ return nil
+ end
+ local parent = tbl.parent
+ local parentName
+ if parent._globalNode then
+ parentName = parent._globalNode:getName()
+ else
+ local locals = guide.getVisibleLocals(state.ast, pos)
+ for _, loc in pairs(locals) do
+ if util.arrayHas(vm.getDefs(loc), tbl) then
+ parentName = loc[1]
+ break
+ end
+ end
+ end
+ local valueEnums = {}
+ for _, field in ipairs(tbl) do
+ if field.type == 'tablefield'
+ or field.type == 'tableindex' then
+ if not field.value then
+ goto CONTINUE
+ end
+ local key = guide.getKeyName(field)
+ if not key then
+ goto CONTINUE
+ end
+ if field.value.type == 'integer'
+ or field.value.type == 'string' then
+ if parentName then
+ enums[#enums+1] = {
+ label = parentName .. '.' .. key,
+ kind = define.CompletionItemKind.EnumMember,
+ id = stack(function () ---@async
+ return {
+ detail = buildDetail(field),
+ description = buildDesc(field),
+ }
+ end),
+ }
+ end
+ valueEnums[#valueEnums+1] = {
+ label = util.viewLiteral(field.value[1]),
+ kind = define.CompletionItemKind.EnumMember,
+ id = stack(function () ---@async
+ return {
+ detail = buildDetail(field),
+ description = buildDesc(field),
+ }
+ end),
+ }
+ end
+ ::CONTINUE::
+ end
+ end
+ for _, enum in ipairs(valueEnums) do
+ enums[#enums+1] = enum
+ end
+ return enums
+end
+
+---@param state parser.state
+---@param pos integer
+---@param src vm.node.object
+---@param enums table[]
+---@param isInArray boolean?
+local function insertEnum(state, pos, src, enums, isInArray)
+ if src.type == 'doc.type.string'
+ or src.type == 'doc.type.integer'
+ or src.type == 'doc.type.boolean' then
+ ---@cast src parser.object
+ enums[#enums+1] = {
+ label = vm.viewObject(src, state.uri),
+ description = src.comment,
+ kind = define.CompletionItemKind.EnumMember,
+ }
+ elseif src.type == 'doc.type.code' then
+ enums[#enums+1] = {
+ label = src[1],
+ description = src.comment,
+ kind = define.CompletionItemKind.EnumMember,
+ }
+ elseif isInArray and src.type == 'doc.type.array' then
+ for i, d in ipairs(vm.getDefs(src.node)) do
+ insertEnum(state, pos, d, enums, isInArray)
+ end
+ elseif src.type == 'global' and src.cate == 'type' then
+ for _, set in ipairs(src:getSets(state.uri)) do
+ if set.type == 'doc.enum' then
+ insertDocEnum(state, pos, set, enums)
+ end
+ end
+ end
+end
+
+local function checkTypingEnum(state, position, defs, str, results, isInArray)
local enums = {}
for _, def in ipairs(defs) do
- if def.type == 'doc.type.string'
- or def.type == 'doc.type.integer'
- or def.type == 'doc.type.boolean' then
- enums[#enums+1] = {
- label = vm.viewObject(def),
- description = def.comment and def.comment.text,
- kind = define.CompletionItemKind.EnumMember,
- }
- end
+ insertEnum(state, position, def, enums, isInArray)
end
cleanEnums(enums, str)
for _, res in ipairs(enums) do
@@ -1143,7 +1235,7 @@ local function checkTypingEnum(state, position, defs, str, results)
end
end
-local function checkEqualEnumLeft(state, position, source, results)
+local function checkEqualEnumLeft(state, position, source, results, isInArray)
if not source then
return
end
@@ -1153,7 +1245,7 @@ local function checkEqualEnumLeft(state, position, source, results)
end
end)
local defs = vm.getDefs(source)
- checkTypingEnum(state, position, defs, str, results)
+ checkTypingEnum(state, position, defs, str, results, isInArray)
end
local function checkEqualEnum(state, position, results)
@@ -1197,15 +1289,24 @@ local function checkEqualEnumInString(state, position, results)
end
checkEqualEnumLeft(state, position, parent[1], results)
end
+ if (parent.type == 'tableexp') then
+ checkEqualEnumLeft(state, position, parent.parent.parent, results, true)
+ return
+ end
if parent.type == 'local' then
checkEqualEnumLeft(state, position, parent, results)
end
+
if parent.type == 'setlocal'
or parent.type == 'setglobal'
or parent.type == 'setfield'
or parent.type == 'setindex' then
checkEqualEnumLeft(state, position, parent.node, results)
end
+ if parent.type == 'tablefield'
+ or parent.type == 'tableindex' then
+ checkEqualEnumLeft(state, position, parent, results)
+ end
end
local function isFuncArg(state, position)
@@ -1234,7 +1335,10 @@ local function tryIndex(state, position, results)
if not parent then
return
end
- local word = parent.next.index[1]
+ local word = parent.next and parent.next.index and parent.next.index[1]
+ if not word then
+ return
+ end
checkField(state, word, position, position, parent, oop, results)
end
@@ -1414,18 +1518,12 @@ local function tryCallArg(state, position, results)
if not node then
return
end
+
local enums = {}
for src in node:eachObject() do
- if src.type == 'doc.type.string'
- or src.type == 'doc.type.integer'
- or src.type == 'doc.type.boolean' then
- enums[#enums+1] = {
- label = vm.viewObject(src),
- description = src.comment,
- kind = define.CompletionItemKind.EnumMember,
- }
- end
+ insertEnum(state, position, src, enums, arg and arg.type == 'table')
if src.type == 'doc.type.function' then
+ ---@cast src parser.object
local insertText = buildInsertDocFunction(src)
local description
if src.comment then
@@ -1439,7 +1537,7 @@ local function tryCallArg(state, position, results)
: string()
end
enums[#enums+1] = {
- label = vm.getInfer(src):view(),
+ label = vm.getInfer(src):view(state.uri),
description = description,
kind = define.CompletionItemKind.Function,
insertText = insertText,
@@ -1467,6 +1565,7 @@ local function tryTable(state, position, results)
if source.type ~= 'table' then
tbl = source.parent
end
+
local defs = vm.getFields(tbl)
for _, field in ipairs(defs) do
local name = guide.getKeyName(field)
@@ -1478,9 +1577,28 @@ local function tryTable(state, position, results)
checkTableLiteralField(state, position, tbl, fields, results)
end
+local function tryArray(state, position, results)
+ local source = findNearestSource(state, position)
+ if not source then
+ return
+ end
+ if source.type ~= 'table' and (not source.parent or source.parent.type ~= 'table') then
+ return
+ end
+ local tbl = source
+ if source.type ~= 'table' then
+ tbl = source.parent
+ end
+ if source.parent.type == 'callargs' and source.parent.parent.type == 'call' then
+ return
+ end
+ -- { } inside when enum
+ checkEqualEnumLeft(state, position, tbl, results, true)
+end
+
local function getComment(state, position)
local offset = guide.positionToOffset(state, position)
- local symbolOffset = lookBackward.findAnyOffset(state.lua, offset)
+ local symbolOffset = lookBackward.findAnyOffset(state.lua, offset, true)
if not symbolOffset then
return
end
@@ -1493,9 +1611,9 @@ local function getComment(state, position)
return nil
end
-local function getluaDoc(state, position)
+local function getLuaDoc(state, position)
local offset = guide.positionToOffset(state, position)
- local symbolOffset = lookBackward.findAnyOffset(state.lua, offset)
+ local symbolOffset = lookBackward.findAnyOffset(state.lua, offset, true)
if not symbolOffset then
return
end
@@ -1528,11 +1646,15 @@ local function tryluaDocCate(word, results)
'async',
'nodiscard',
'cast',
+ 'operator',
+ 'source',
+ 'enum',
} do
if matchKey(word, docType) then
results[#results+1] = {
label = docType,
kind = define.CompletionItemKind.Event,
+ description = lang.script('LUADOC_DESC_' .. docType:upper())
}
end
end
@@ -1608,6 +1730,7 @@ local function tryluaDocBySource(state, position, source, results)
for _, doc in ipairs(vm.getDocSets(state.uri)) do
local name = (doc.type == 'doc.class' and doc.class[1])
or (doc.type == 'doc.alias' and doc.alias[1])
+ or (doc.type == 'doc.enum' and doc.enum[1])
if name
and not used[name]
and matchKey(source[1], name) then
@@ -1697,6 +1820,35 @@ local function tryluaDocBySource(state, position, source, results)
end
end
return true
+ elseif source.type == 'doc.operator.name' then
+ for _, name in ipairs(vm.UNARY_OP) do
+ if matchKey(source[1], name) then
+ results[#results+1] = {
+ label = name,
+ kind = define.CompletionItemKind.Operator,
+ description = ('```lua\n%s\n```'):format(vm.OP_UNARY_MAP[name]),
+ }
+ end
+ end
+ for _, name in ipairs(vm.BINARY_OP) do
+ if matchKey(source[1], name) then
+ results[#results+1] = {
+ label = name,
+ kind = define.CompletionItemKind.Operator,
+ description = ('```lua\n%s\n```'):format(vm.OP_BINARY_MAP[name]),
+ }
+ end
+ end
+ for _, name in ipairs(vm.OTHER_OP) do
+ if matchKey(source[1], name) then
+ results[#results+1] = {
+ label = name,
+ kind = define.CompletionItemKind.Operator,
+ description = ('```lua\n%s\n```'):format(vm.OP_OTHER_MAP[name]),
+ }
+ end
+ end
+ return true
end
return false
end
@@ -1734,6 +1886,14 @@ local function tryluaDocByErr(state, position, err, docState, results)
kind = define.CompletionItemKind.Class,
}
end
+ if doc.type == 'doc.enum'
+ and not used[doc.enum[1]] then
+ used[doc.enum[1]] = true
+ results[#results+1] = {
+ label = doc.enum[1],
+ kind = define.CompletionItemKind.Enum,
+ }
+ end
end
elseif err.type == 'LUADOC_MISS_PARAM_NAME' then
local funcs = {}
@@ -1783,7 +1943,7 @@ local function tryluaDocByErr(state, position, err, docState, results)
}
end
elseif err.type == 'LUADOC_MISS_DIAG_NAME' then
- for name in util.sortPairs(define.DiagnosticDefaultSeverity) do
+ for name in util.sortPairs(diag.getDiagAndErrNameMap()) do
results[#results+1] = {
label = name,
kind = define.CompletionItemKind.Value,
@@ -1807,6 +1967,28 @@ local function tryluaDocByErr(state, position, err, docState, results)
}
end
end
+ elseif err.type == 'LUADOC_MISS_OPERATOR_NAME' then
+ for _, name in ipairs(vm.UNARY_OP) do
+ results[#results+1] = {
+ label = name,
+ kind = define.CompletionItemKind.Operator,
+ description = ('```lua\n%s\n```'):format(vm.OP_UNARY_MAP[name]),
+ }
+ end
+ for _, name in ipairs(vm.BINARY_OP) do
+ results[#results+1] = {
+ label = name,
+ kind = define.CompletionItemKind.Operator,
+ description = ('```lua\n%s\n```'):format(vm.OP_BINARY_MAP[name]),
+ }
+ end
+ for _, name in ipairs(vm.OTHER_OP) do
+ results[#results+1] = {
+ label = name,
+ kind = define.CompletionItemKind.Operator,
+ description = ('```lua\n%s\n```'):format(vm.OP_OTHER_MAP[name]),
+ }
+ end
end
end
@@ -1818,14 +2000,14 @@ local function buildluaDocOfFunction(func)
local returns = {}
if func.args then
for _, arg in ipairs(func.args) do
- args[#args+1] = vm.getInfer(arg):view()
+ args[#args+1] = vm.getInfer(arg):view(guide.getUri(func))
end
end
if func.returns then
for _, rtns in ipairs(func.returns) do
for n = 1, #rtns do
if not returns[n] then
- returns[n] = vm.getInfer(rtns[n]):view()
+ returns[n] = vm.getInfer(rtns[n]):view(guide.getUri(func))
end
end
end
@@ -1853,25 +2035,27 @@ local function buildluaDocOfFunction(func)
end
local function tryluaDocOfFunction(doc, results)
- if not doc.bindSources then
+ if not doc.bindSource then
return
end
- local func
- for _, source in ipairs(doc.bindSources) do
- if source.type == 'function' then
- func = source
- break
- end
- end
+ local func = (doc.bindSource.type == 'function' and doc.bindSource)
+ or (doc.bindSource.value and doc.bindSource.value.type == 'function' and doc.bindSource.value)
+ or nil
if not func then
return
end
for _, otherDoc in ipairs(doc.bindGroup) do
- if otherDoc.type == 'doc.param'
- or otherDoc.type == 'doc.return' then
+ if otherDoc.type == 'doc.return' then
return
end
end
+ if func.args then
+ for _, param in ipairs(func.args) do
+ if param.bindDocs then
+ return
+ end
+ end
+ end
local insertText = buildluaDocOfFunction(func)
results[#results+1] = {
label = '@param;@return',
@@ -1882,8 +2066,8 @@ local function tryluaDocOfFunction(doc, results)
}
end
-local function tryluaDoc(state, position, results)
- local doc = getluaDoc(state, position)
+local function tryLuaDoc(state, position, results)
+ local doc = getLuaDoc(state, position)
if not doc then
return
end
@@ -1922,7 +2106,7 @@ local function tryComment(state, position, results)
return
end
local word = lookBackward.findWord(state.lua, guide.positionToOffset(state, position))
- local doc = getluaDoc(state, position)
+ local doc = getLuaDoc(state, position)
if not word then
local comment = getComment(state, position)
if not comment then
@@ -1961,7 +2145,7 @@ local function tryCompletions(state, position, triggerCharacter, results)
return
end
if getComment(state, position) then
- tryluaDoc(state, position, results)
+ tryLuaDoc(state, position, results)
tryComment(state, position, results)
return
end
@@ -1971,6 +2155,7 @@ local function tryCompletions(state, position, triggerCharacter, results)
trySpecial(state, position, results)
tryCallArg(state, position, results)
tryTable(state, position, results)
+ tryArray(state, position, results)
tryWord(state, position, triggerCharacter, results)
tryIndex(state, position, results)
trySymbol(state, position, results)
@@ -1983,8 +2168,6 @@ local function completion(uri, position, triggerCharacter)
return nil
end
clearStack()
- vm.lockCache()
- local _ <close> = vm.unlockCache
local results = {}
tracy.ZoneBeginN 'completion #2'
tryCompletions(state, position, triggerCharacter, results)
diff --git a/script/core/definition.lua b/script/core/definition.lua
index e4868532..866e8f84 100644
--- a/script/core/definition.lua
+++ b/script/core/definition.lua
@@ -4,6 +4,7 @@ local vm = require 'vm'
local findSource = require 'core.find-source'
local guide = require 'parser.guide'
local rpath = require 'workspace.require-path'
+local jumpSource = require 'core.jump-source'
local function sortResults(results)
-- 先按照顺序排序
@@ -54,6 +55,7 @@ local accept = {
['doc.see.name'] = true,
['doc.see.field'] = true,
['doc.cast.name'] = true,
+ ['doc.enum.name'] = true,
}
local function checkRequire(source, offset)
@@ -75,7 +77,7 @@ local function checkRequire(source, offset)
return nil
end
if libName == 'require' then
- return rpath.findUrisByRequirePath(guide.getUri(source), literal)
+ return rpath.findUrisByRequireName(guide.getUri(source), literal)
elseif libName == 'dofile'
or libName == 'loadfile' then
return workspace.findUrisByFilePath(literal)
@@ -169,8 +171,12 @@ return function (uri, offset)
if src.type == 'doc.alias' then
src = src.alias
end
+ if src.type == 'doc.enum' then
+ src = src.enum
+ end
if src.type == 'doc.class.name'
- or src.type == 'doc.alias.name' then
+ or src.type == 'doc.alias.name'
+ or src.type == 'doc.enum.name' then
if source.type ~= 'doc.type.name'
and source.type ~= 'doc.extends.name'
and source.type ~= 'doc.see.name' then
@@ -197,6 +203,7 @@ return function (uri, offset)
end
sortResults(results)
+ jumpSource(results)
return results
end
diff --git a/script/core/diagnostics/ambiguity-1.lua b/script/core/diagnostics/ambiguity-1.lua
index f03f4361..830b2f2f 100644
--- a/script/core/diagnostics/ambiguity-1.lua
+++ b/script/core/diagnostics/ambiguity-1.lua
@@ -27,10 +27,10 @@ local literalMap = {
return function (uri, callback)
local state = files.getState(uri)
- if not state then
+ local text = files.getText(uri)
+ if not state or not text then
return
end
- local text = files.getText(uri)
guide.eachSourceType(state.ast, 'binary', function (source)
if source.op.type ~= 'or' then
return
diff --git a/script/core/diagnostics/assign-type-mismatch.lua b/script/core/diagnostics/assign-type-mismatch.lua
new file mode 100644
index 00000000..2d5c3f98
--- /dev/null
+++ b/script/core/diagnostics/assign-type-mismatch.lua
@@ -0,0 +1,117 @@
+local files = require 'files'
+local lang = require 'language'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local await = require 'await'
+
+local checkTypes = {
+ 'local',
+ 'setlocal',
+ 'setglobal',
+ 'setfield',
+ 'setindex',
+ 'setmethod',
+ 'tablefield',
+ 'tableindex'
+}
+
+---@param source parser.object
+---@return boolean
+local function hasMarkType(source)
+ if not source.bindDocs then
+ return false
+ end
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.type'
+ or doc.type == 'doc.class' then
+ return true
+ end
+ end
+ return false
+end
+
+---@param source parser.object
+---@return boolean
+local function hasMarkClass(source)
+ if not source.bindDocs then
+ return false
+ end
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.class' then
+ return true
+ end
+ end
+ return false
+end
+
+---@async
+return function (uri, callback)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+
+ ---@async
+ guide.eachSourceTypes(state.ast, checkTypes, function (source)
+ local value = source.value
+ if not value then
+ return
+ end
+ await.delay()
+ if source.type == 'setlocal' then
+ local locNode = vm.compileNode(source.node)
+ if not locNode:getData 'hasDefined' then
+ return
+ end
+ end
+ if value.type == 'nil' then
+ --[[
+ ---@class A
+ local mt
+ ---@type X
+ mt._x = nil -- don't warn this
+ ]]
+ if hasMarkType(source) then
+ return
+ end
+ if source.type == 'setfield'
+ or source.type == 'setindex' then
+ return
+ end
+ end
+
+ local valueNode = vm.compileNode(value)
+ if source.type == 'setindex' then
+ -- boolean[1] = nil
+ valueNode = valueNode:copy():removeOptional()
+ end
+
+ if value.type == 'getfield'
+ or value.type == 'getindex' then
+ -- 由于无法对字段进行类型收窄,
+ -- 因此将假值移除再进行检查
+ valueNode = valueNode:copy():setTruthy()
+ end
+
+ local varNode = vm.compileNode(source)
+ if vm.canCastType(uri, varNode, valueNode) then
+ return
+ end
+
+ -- local Cat = setmetatable({}, {__index = Animal}) 允许逆变
+ if hasMarkClass(source) then
+ if vm.canCastType(uri, valueNode:copy():remove 'table', varNode) then
+ return
+ end
+ end
+
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script('DIAG_ASSIGN_TYPE_MISMATCH', {
+ def = vm.getInfer(varNode):view(uri),
+ ref = vm.getInfer(valueNode):view(uri),
+ }),
+ }
+ end)
+end
diff --git a/script/core/diagnostics/cast-local-type.lua b/script/core/diagnostics/cast-local-type.lua
new file mode 100644
index 00000000..c3d6e1bb
--- /dev/null
+++ b/script/core/diagnostics/cast-local-type.lua
@@ -0,0 +1,50 @@
+local files = require 'files'
+local lang = require 'language'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local await = require 'await'
+
+---@async
+return function (uri, callback)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+
+ ---@async
+ guide.eachSourceType(state.ast, 'local', function (loc)
+ if not loc.ref then
+ return
+ end
+ await.delay()
+ local locNode = vm.compileNode(loc)
+ if not locNode:getData 'hasDefined' then
+ return
+ end
+ for _, ref in ipairs(loc.ref) do
+ if ref.type == 'setlocal' and ref.value then
+ await.delay()
+ local refNode = vm.compileNode(ref)
+ local value = ref.value
+
+ if value.type == 'getfield'
+ or value.type == 'getindex' then
+ -- 由于无法对字段进行类型收窄,
+ -- 因此将假值移除再进行检查
+ refNode = refNode:copy():setTruthy()
+ end
+
+ if not vm.canCastType(uri, locNode, refNode) then
+ callback {
+ start = ref.start,
+ finish = ref.finish,
+ message = lang.script('DIAG_CAST_LOCAL_TYPE', {
+ def = vm.getInfer(locNode):view(uri),
+ ref = vm.getInfer(refNode):view(uri),
+ }),
+ }
+ end
+ end
+ end
+ end)
+end
diff --git a/script/core/diagnostics/cast-type-mismatch.lua b/script/core/diagnostics/cast-type-mismatch.lua
new file mode 100644
index 00000000..a48e6cca
--- /dev/null
+++ b/script/core/diagnostics/cast-type-mismatch.lua
@@ -0,0 +1,45 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local vm = require 'vm'
+local await = require 'await'
+
+---@async
+return function (uri, callback)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ for _, doc in ipairs(state.ast.docs) do
+ if doc.type == 'doc.cast' and doc.loc then
+ await.delay()
+ local defs = vm.getDefs(doc.loc)
+ local loc = defs[1]
+ if loc then
+ local defNode = vm.compileNode(loc)
+ if defNode:getData 'hasDefined' then
+ for _, cast in ipairs(doc.casts) do
+ if not cast.mode and cast.extends then
+ local refNode = vm.compileNode(cast.extends)
+ if not vm.canCastType(uri, defNode, refNode) then
+ callback {
+ start = cast.extends.start,
+ finish = cast.extends.finish,
+ message = lang.script('DIAG_CAST_TYPE_MISMATCH', {
+ def = vm.getInfer(defNode):view(uri),
+ ref = vm.getInfer(refNode):view(uri),
+ })
+ }
+ end
+ end
+ end
+ end
+ end
+ end
+ end
+end
diff --git a/script/core/diagnostics/circle-doc-class.lua b/script/core/diagnostics/circle-doc-class.lua
index 40d4afeb..fcd2021d 100644
--- a/script/core/diagnostics/circle-doc-class.lua
+++ b/script/core/diagnostics/circle-doc-class.lua
@@ -2,7 +2,9 @@ local files = require 'files'
local lang = require 'language'
local vm = require 'vm'
local guide = require 'parser.guide'
+local await = require 'await'
+---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
@@ -18,6 +20,7 @@ return function (uri, callback)
if not doc.extends then
goto CONTINUE
end
+ await.delay()
local myName = guide.getKeyName(doc)
local list = { doc }
local mark = {}
diff --git a/script/core/diagnostics/close-non-object.lua b/script/core/diagnostics/close-non-object.lua
index c97014fa..1a42b800 100644
--- a/script/core/diagnostics/close-non-object.lua
+++ b/script/core/diagnostics/close-non-object.lua
@@ -25,10 +25,11 @@ return function (uri, callback)
return
end
local infer = vm.getInfer(source.value)
- if not infer:hasClass()
- and not infer:hasType 'nil'
- and not infer:hasType 'table'
- and infer:view('any', uri) ~= 'any' then
+ if not infer:hasClass(uri)
+ and not infer:hasType(uri, 'nil')
+ and not infer:hasType(uri, 'table')
+ and not infer:hasUnknown(uri)
+ and not infer:hasAny(uri) then
callback {
start = source.value.start,
finish = source.value.finish,
diff --git a/script/core/diagnostics/code-after-break.lua b/script/core/diagnostics/code-after-break.lua
index 21f7e83a..963fd9ed 100644
--- a/script/core/diagnostics/code-after-break.lua
+++ b/script/core/diagnostics/code-after-break.lua
@@ -2,7 +2,9 @@ local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local define = require 'proto.define'
+local await = require 'await'
+---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
@@ -10,12 +12,14 @@ return function (uri, callback)
end
local mark = {}
+ ---@async
guide.eachSourceType(state.ast, 'break', function (source)
local list = source.parent
if mark[list] then
return
end
mark[list] = true
+ await.delay()
for i = #list, 1, -1 do
local src = list[i]
if src == source then
diff --git a/script/core/diagnostics/codestyle-check.lua b/script/core/diagnostics/codestyle-check.lua
index 34d55ee2..25603b4b 100644
--- a/script/core/diagnostics/codestyle-check.lua
+++ b/script/core/diagnostics/codestyle-check.lua
@@ -7,7 +7,7 @@ local pformatting = require 'provider.formatting'
---@async
return function(uri, callback)
- local text = files.getText(uri)
+ local text = files.getOriginText(uri)
if not text then
return
end
diff --git a/script/core/diagnostics/count-down-loop.lua b/script/core/diagnostics/count-down-loop.lua
index 9bc4b273..bd6e5ee3 100644
--- a/script/core/diagnostics/count-down-loop.lua
+++ b/script/core/diagnostics/count-down-loop.lua
@@ -10,12 +10,15 @@ return function (uri, callback)
end
guide.eachSourceType(state.ast, 'loop', function (source)
- local maxNumer = source.max and tonumber(source.max[1])
- if maxNumer ~= 1 then
+ local maxNumber = source.max and tonumber(source.max[1])
+ if not maxNumber then
return
end
local minNumber = source.init and tonumber(source.init[1])
- if minNumber and minNumber <= 1 then
+ if minNumber and maxNumber and minNumber <= maxNumber then
+ return
+ end
+ if not minNumber and maxNumber ~= 1 then
return
end
if not source.step then
@@ -24,7 +27,7 @@ return function (uri, callback)
finish = source.max.finish,
message = lang.script('DIAG_COUNT_DOWN_LOOP'
, ('%s, %s'):format(text:sub(
- guide.positionToOffset(state, source.init.start),
+ guide.positionToOffset(state, source.init.start + 1),
guide.positionToOffset(state, source.max.finish)
), '-1')
)
@@ -37,7 +40,7 @@ return function (uri, callback)
finish = source.step.finish,
message = lang.script('DIAG_COUNT_DOWN_LOOP'
, ('%s, -%s'):format(text:sub(
- guide.positionToOffset(state, source.init.start),
+ guide.positionToOffset(state, source.init.start + 1),
guide.positionToOffset(state, source.max.finish)
), source.step[1])
)
diff --git a/script/core/diagnostics/deprecated.lua b/script/core/diagnostics/deprecated.lua
index 27920c43..85ae2d19 100644
--- a/script/core/diagnostics/deprecated.lua
+++ b/script/core/diagnostics/deprecated.lua
@@ -15,7 +15,7 @@ return function (uri, callback)
return
end
- local dglobals = config.get(uri, 'Lua.diagnostics.globals')
+ local dglobals = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals'))
local rspecial = config.get(uri, 'Lua.runtime.special')
guide.eachSourceTypes(ast.ast, types, function (src) ---@async
diff --git a/script/core/diagnostics/different-requires.lua b/script/core/diagnostics/different-requires.lua
index de063c9f..22e3e681 100644
--- a/script/core/diagnostics/different-requires.lua
+++ b/script/core/diagnostics/different-requires.lua
@@ -21,7 +21,7 @@ return function (uri, callback)
return
end
local literal = arg1[1]
- local results = rpath.findUrisByRequirePath(uri, literal)
+ local results = rpath.findUrisByRequireName(uri, literal)
if not results or #results ~= 1 then
return
end
diff --git a/script/core/diagnostics/duplicate-doc-alias.lua b/script/core/diagnostics/duplicate-doc-alias.lua
index 3df6f972..360358e4 100644
--- a/script/core/diagnostics/duplicate-doc-alias.lua
+++ b/script/core/diagnostics/duplicate-doc-alias.lua
@@ -2,7 +2,9 @@ local files = require 'files'
local lang = require 'language'
local vm = require 'vm'
local guide = require 'parser.guide'
+local await = require 'await'
+---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
@@ -15,14 +17,20 @@ return function (uri, callback)
local cache = {}
for _, doc in ipairs(state.ast.docs) do
- if doc.type == 'doc.alias' then
+ if doc.type == 'doc.alias'
+ or doc.type == 'doc.enum' then
local name = guide.getKeyName(doc)
+ if not name then
+ return
+ end
+ await.delay()
if not cache[name] then
local docs = vm.getDocSets(uri, name)
cache[name] = {}
for _, otherDoc in ipairs(docs) do
if otherDoc.type == 'doc.alias'
- or otherDoc.type == 'doc.class' then
+ or otherDoc.type == 'doc.class'
+ or otherDoc.type == 'doc.enum' then
cache[name][#cache[name]+1] = {
start = otherDoc.start,
finish = otherDoc.finish,
@@ -33,10 +41,10 @@ return function (uri, callback)
end
if #cache[name] > 1 then
callback {
- start = doc.alias.start,
- finish = doc.alias.finish,
+ start = (doc.alias or doc.enum).start,
+ finish = (doc.alias or doc.enum).finish,
related = cache,
- message = lang.script('DIAG_DUPLICATE_DOC_CLASS', name)
+ message = lang.script('DIAG_DUPLICATE_DOC_ALIAS', name)
}
end
end
diff --git a/script/core/diagnostics/duplicate-doc-field.lua b/script/core/diagnostics/duplicate-doc-field.lua
index d4116b9b..a30dfa88 100644
--- a/script/core/diagnostics/duplicate-doc-field.lua
+++ b/script/core/diagnostics/duplicate-doc-field.lua
@@ -1,5 +1,7 @@
local files = require 'files'
local lang = require 'language'
+local vm = require 'vm.vm'
+local await = require 'await'
local function getFieldEventName(doc)
if not doc.extends then
@@ -28,6 +30,7 @@ local function getFieldEventName(doc)
return nil
end
+---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
@@ -45,7 +48,13 @@ return function (uri, callback)
mark = {}
elseif doc.type == 'doc.field' then
if mark then
- local name = ('%q'):format(doc.field[1])
+ await.delay()
+ local name
+ if doc.field.type == 'doc.type' then
+ name = ('[%s]'):format(vm.getInfer(doc.field):view(uri))
+ else
+ name = ('%q'):format(doc.field[1])
+ end
local eventName = getFieldEventName(doc)
if eventName then
name = name .. '|' .. eventName
diff --git a/script/core/diagnostics/duplicate-index.lua b/script/core/diagnostics/duplicate-index.lua
index 5097ab3a..dfd9bd4b 100644
--- a/script/core/diagnostics/duplicate-index.lua
+++ b/script/core/diagnostics/duplicate-index.lua
@@ -2,14 +2,17 @@ local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local define = require 'proto.define'
+local await = require 'await'
+---@async
return function (uri, callback)
local ast = files.getState(uri)
if not ast then
return
end
-
+ ---@async
guide.eachSourceType(ast.ast, 'table', function (source)
+ await.delay()
local mark = {}
for _, obj in ipairs(source) do
if obj.type == 'tablefield'
diff --git a/script/core/diagnostics/duplicate-set-field.lua b/script/core/diagnostics/duplicate-set-field.lua
index 8052c420..ce67ab46 100644
--- a/script/core/diagnostics/duplicate-set-field.lua
+++ b/script/core/diagnostics/duplicate-set-field.lua
@@ -3,17 +3,21 @@ local lang = require 'language'
local define = require 'proto.define'
local guide = require 'parser.guide'
local vm = require 'vm'
+local await = require 'await'
+---@async
return function (uri, callback)
local ast = files.getState(uri)
if not ast then
return
end
+ ---@async
guide.eachSourceType(ast.ast, 'local', function (source)
if not source.ref then
return
end
+ await.delay()
local sets = {}
for _, ref in ipairs(source.ref) do
if ref.type ~= 'getlocal' then
@@ -48,10 +52,12 @@ return function (uri, callback)
local blocks = {}
for _, value in ipairs(values) do
local block = guide.getBlock(value)
- if not blocks[block] then
- blocks[block] = {}
+ if block then
+ if not blocks[block] then
+ blocks[block] = {}
+ end
+ blocks[block][#blocks[block]+1] = value
end
- blocks[block][#blocks[block]+1] = value
end
for _, defs in pairs(blocks) do
if #defs <= 1 then
diff --git a/script/core/diagnostics/empty-block.lua b/script/core/diagnostics/empty-block.lua
index fc205d7e..e05b6aef 100644
--- a/script/core/diagnostics/empty-block.lua
+++ b/script/core/diagnostics/empty-block.lua
@@ -2,15 +2,18 @@ local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local define = require 'proto.define'
+local await = require 'await'
--- 检查空代码块
+-- 检查空代码块
-- 但是排除忙等待(repeat/while)
+---@async
return function (uri, callback)
local ast = files.getState(uri)
if not ast then
return
end
+ await.delay()
guide.eachSourceType(ast.ast, 'if', function (source)
for _, block in ipairs(source) do
if #block > 0 then
@@ -24,6 +27,7 @@ return function (uri, callback)
message = lang.script.DIAG_EMPTY_BLOCK,
}
end)
+ await.delay()
guide.eachSourceType(ast.ast, 'loop', function (source)
if #source > 0 then
return
@@ -35,6 +39,7 @@ return function (uri, callback)
message = lang.script.DIAG_EMPTY_BLOCK,
}
end)
+ await.delay()
guide.eachSourceType(ast.ast, 'in', function (source)
if #source > 0 then
return
diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua
index 334fd81a..e154080c 100644
--- a/script/core/diagnostics/global-in-nil-env.lua
+++ b/script/core/diagnostics/global-in-nil-env.lua
@@ -2,65 +2,35 @@ local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
--- TODO: 检查路径是否可达
-local function mayRun(path)
- return true
-end
-
return function (uri, callback)
- local ast = files.getState(uri)
- if not ast then
- return
- end
- local root = guide.getRoot(ast.ast)
- local env = guide.getENV(root)
-
- local nilDefs = {}
- if not env or not env.ref then
- return
- end
- for _, ref in ipairs(env.ref) do
- if ref.type == 'setlocal' then
- if ref.value and ref.value.type == 'nil' then
- nilDefs[#nilDefs+1] = ref
- end
- end
- end
-
- if #nilDefs == 0 then
+ local state = files.getState(uri)
+ if not state then
return
end
local function check(source)
local node = source.node
if node.tag == '_ENV' then
- local ok
- for _, nilDef in ipairs(nilDefs) do
- local mode, pathA = guide.getPath(nilDef, source)
- if mode == 'before'
- and mayRun(pathA) then
- ok = nilDef
- break
- end
- end
- if ok then
- callback {
- start = source.start,
- finish = source.finish,
- uri = uri,
- message = lang.script.DIAG_GLOBAL_IN_NIL_ENV,
- related = {
- {
- start = ok.start,
- finish = ok.finish,
- uri = uri,
- }
+ return
+ end
+
+ if not node.value or node.value.type == 'nil' then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ uri = uri,
+ message = lang.script.DIAG_GLOBAL_IN_NIL_ENV,
+ related = {
+ {
+ start = node.start,
+ finish = node.finish,
+ uri = uri,
}
}
- end
+ }
end
end
- guide.eachSourceType(ast.ast, 'getglobal', check)
- guide.eachSourceType(ast.ast, 'setglobal', check)
+ guide.eachSourceType(state.ast, 'getglobal', check)
+ guide.eachSourceType(state.ast, 'setglobal', check)
end
diff --git a/script/core/diagnostics/init.lua b/script/core/diagnostics/init.lua
index b4ae3715..c33de6ce 100644
--- a/script/core/diagnostics/init.lua
+++ b/script/core/diagnostics/init.lua
@@ -3,14 +3,22 @@ local define = require 'proto.define'
local config = require 'config'
local await = require 'await'
local vm = require "vm.vm"
+local util = require 'utility'
+local diagd = require 'proto.diagnostic'
-- 把耗时最长的诊断放到最后面
local diagSort = {
- ['redundant-value'] = 96,
- ['not-yieldable'] = 97,
- ['deprecated'] = 98,
- ['undefined-field'] = 99,
- ['redundant-parameter'] = 100,
+ ['redundant-value'] = 100,
+ ['not-yieldable'] = 100,
+ ['deprecated'] = 100,
+ ['undefined-field'] = 110,
+ ['redundant-parameter'] = 110,
+ ['cast-local-type'] = 120,
+ ['assign-type-mismatch'] = 120,
+ ['param-type-mismatch'] = 120,
+ ['missing-return'] = 120,
+ ['missing-return-value'] = 120,
+ ['redundant-return-value'] = 120,
}
local diagList = {}
@@ -46,30 +54,86 @@ local function checkSleep(uri, passed)
sleepRest = sleepRest - sleeped
end
+---@param uri uri
+---@param name string
+---@return string
+local function getSeverity(uri, name)
+ local severity = config.get(uri, 'Lua.diagnostics.severity')[name]
+ or define.DiagnosticDefaultSeverity[name]
+ if severity:sub(-1) == '!' then
+ return severity:sub(1, -2)
+ end
+ local groupSeverity = config.get(uri, 'Lua.diagnostics.groupSeverity')
+ local groups = diagd.getGroups(name)
+ local groupLevel = 999
+ for _, groupName in ipairs(groups) do
+ local gseverity = groupSeverity[groupName]
+ if gseverity and gseverity ~= 'Fallback' then
+ groupLevel = math.min(groupLevel, define.DiagnosticSeverity[gseverity])
+ end
+ end
+ if groupLevel == 999 then
+ return severity
+ end
+ for severityName, level in pairs(define.DiagnosticSeverity) do
+ if level == groupLevel then
+ return severityName
+ end
+ end
+ return severity
+end
+
+---@param uri uri
+---@param name string
+---@return string
+local function getStatus(uri, name)
+ local status = config.get(uri, 'Lua.diagnostics.neededFileStatus')[name]
+ or define.DiagnosticDefaultNeededFileStatus[name]
+ if status:sub(-1) == '!' then
+ return status:sub(1, -2)
+ end
+ local groupStatus = config.get(uri, 'Lua.diagnostics.groupFileStatus')
+ local groups = diagd.getGroups(name)
+ local groupLevel = 0
+ for _, groupName in ipairs(groups) do
+ local gstatus = groupStatus[groupName]
+ if gstatus and gstatus ~= 'Fallback' then
+ groupLevel = math.max(groupLevel, define.DiagnosticFileStatus[gstatus])
+ end
+ end
+ if groupLevel == 0 then
+ return status
+ end
+ for statusName, level in pairs(define.DiagnosticFileStatus) do
+ if level == groupLevel then
+ return statusName
+ end
+ end
+ return status
+end
+
---@async
---@param uri uri
---@param name string
---@param isScopeDiag boolean
---@param response async fun(result: any)
local function check(uri, name, isScopeDiag, response)
- if config.get(uri, 'Lua.diagnostics.disable')[name] then
+ local disables = config.get(uri, 'Lua.diagnostics.disable')
+ if util.arrayHas(disables, name) then
return
end
- local level = config.get(uri, 'Lua.diagnostics.severity')[name]
- or define.DiagnosticDefaultSeverity[name]
-
- local neededFileStatus = config.get(uri, 'Lua.diagnostics.neededFileStatus')[name]
- or define.DiagnosticDefaultNeededFileStatus[name]
+ local severity = getSeverity(uri, name)
+ local status = getStatus(uri, name)
- if neededFileStatus == 'None' then
+ if status == 'None' then
return
end
- if neededFileStatus == 'Opened' and not files.isOpen(uri) then
+ if status == 'Opened' and not files.isOpen(uri) then
return
end
- local severity = define.DiagnosticSeverity[level]
+ local level = define.DiagnosticSeverity[severity]
local clock = os.clock()
local mark = {}
---@async
@@ -85,7 +149,7 @@ local function check(uri, name, isScopeDiag, response)
end
mark[result.start] = true
- result.level = severity or result.level
+ result.level = level or result.level
result.code = name
response(result)
end, name)
diff --git a/script/core/diagnostics/lowercase-global.lua b/script/core/diagnostics/lowercase-global.lua
index d03e8c70..68bec234 100644
--- a/script/core/diagnostics/lowercase-global.lua
+++ b/script/core/diagnostics/lowercase-global.lua
@@ -3,6 +3,7 @@ local guide = require 'parser.guide'
local lang = require 'language'
local config = require 'config'
local vm = require 'vm'
+local util = require 'utility'
local function isDocClass(source)
if not source.bindDocs then
@@ -23,10 +24,7 @@ return function (uri, callback)
return
end
- local definedGlobal = {}
- for name in pairs(config.get(uri, 'Lua.diagnostics.globals')) do
- definedGlobal[name] = true
- end
+ local definedGlobal = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals'))
guide.eachSourceType(ast.ast, 'setglobal', function (source)
local name = guide.getKeyName(source)
diff --git a/script/core/diagnostics/missing-parameter.lua b/script/core/diagnostics/missing-parameter.lua
index 698680ca..78b94a09 100644
--- a/script/core/diagnostics/missing-parameter.lua
+++ b/script/core/diagnostics/missing-parameter.lua
@@ -2,68 +2,27 @@ local files = require 'files'
local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
+local await = require 'await'
-local function countCallArgs(source)
- local result = 0
- if not source.args then
- return 0
- end
- result = result + #source.args
- return result
-end
-
----@return integer
-local function countFuncArgs(source)
- if not source.args or #source.args == 0 then
- return 0
- end
- local count = 0
- for i = #source.args, 1, -1 do
- local arg = source.args[i]
- if arg.type ~= '...'
- and not (arg.name and arg.name[1] =='...')
- and not vm.compileNode(arg):isNullable() then
- return i
- end
- end
- return count
-end
-
-local function getFuncArgs(func)
- local funcArgs
- local defs = vm.getDefs(func)
- for _, def in ipairs(defs) do
- if def.type == 'function'
- or def.type == 'doc.type.function' then
- local args = countFuncArgs(def)
- if not funcArgs or args < funcArgs then
- funcArgs = args
- end
- end
- end
- return funcArgs
-end
-
+---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
+ ---@async
guide.eachSourceType(state.ast, 'call', function (source)
- local callArgs = countCallArgs(source)
+ await.delay()
+ local _, callArgs = vm.countList(source.args)
- local func = source.node
- local funcArgs = getFuncArgs(func)
+ local funcNode = vm.compileNode(source.node)
+ local funcArgs = vm.countParamsOfNode(funcNode)
- if not funcArgs then
+ if callArgs >= funcArgs then
return
end
- local delta = callArgs - funcArgs
- if delta >= 0 then
- return
- end
callback {
start = source.start,
finish = source.finish,
diff --git a/script/core/diagnostics/missing-return-value.lua b/script/core/diagnostics/missing-return-value.lua
new file mode 100644
index 00000000..2156d66c
--- /dev/null
+++ b/script/core/diagnostics/missing-return-value.lua
@@ -0,0 +1,66 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local lang = require 'language'
+local await = require 'await'
+
+local function hasDocReturn(func)
+ if not func.bindDocs then
+ return false
+ end
+ for _, doc in ipairs(func.bindDocs) do
+ if doc.type == 'doc.return' then
+ return true
+ end
+ end
+ return false
+end
+
+---@async
+return function (uri, callback)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+
+ ---@async
+ guide.eachSourceType(state.ast, 'function', function (source)
+ await.delay()
+ if not hasDocReturn(source) then
+ return
+ end
+ local min = vm.countReturnsOfFunction(source)
+ if min == 0 then
+ return
+ end
+ local returns = source.returns
+ if not returns then
+ return
+ end
+ for _, ret in ipairs(returns) do
+ local rmin, rmax = vm.countList(ret)
+ if rmax < min then
+ if rmin == rmax then
+ callback {
+ start = ret.start,
+ finish = ret.start + #'return',
+ message = lang.script('DIAG_MISSING_RETURN_VALUE', {
+ min = min,
+ rmax = rmax,
+ }),
+ }
+ else
+ callback {
+ start = ret.start,
+ finish = ret.start + #'return',
+ message = lang.script('DIAG_MISSING_RETURN_VALUE_RANGE', {
+ min = min,
+ rmin = rmin,
+ rmax = rmax,
+ }),
+ }
+ end
+ end
+ end
+ end)
+end
diff --git a/script/core/diagnostics/missing-return.lua b/script/core/diagnostics/missing-return.lua
new file mode 100644
index 00000000..e3539ac0
--- /dev/null
+++ b/script/core/diagnostics/missing-return.lua
@@ -0,0 +1,86 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local lang = require 'language'
+local await = require 'await'
+
+---@param block parser.object
+---@return boolean
+local function hasReturn(block)
+ if block.hasReturn or block.hasError then
+ return true
+ end
+ if block.type == 'if' then
+ local hasElse
+ for _, subBlock in ipairs(block) do
+ if not hasReturn(subBlock) then
+ return false
+ end
+ if subBlock.type == 'elseblock' then
+ hasElse = true
+ end
+ end
+ return hasElse == true
+ else
+ if block.type == 'while' then
+ if vm.testCondition(block.filter) then
+ return true
+ end
+ end
+ for _, action in ipairs(block) do
+ if guide.isBlockType(action) then
+ if hasReturn(action) then
+ return true
+ end
+ end
+ end
+ end
+ return false
+end
+
+---@param func parser.object
+---@return boolean
+local function isEmptyFunction(func)
+ if #func > 0 then
+ return false
+ end
+ local startRow = guide.rowColOf(func.start)
+ local finishRow = guide.rowColOf(func.finish)
+ return finishRow - startRow <= 1
+end
+
+---@async
+return function (uri, callback)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+
+ ---@async
+ guide.eachSourceType(state.ast, 'function', function (source)
+ -- check declare only
+ if isEmptyFunction(source) then
+ return
+ end
+ await.delay()
+ if vm.countReturnsOfFunction(source, true) == 0 then
+ return
+ end
+ if hasReturn(source) then
+ return
+ end
+ local lastAction = source[#source]
+ local pos
+ if lastAction then
+ pos = lastAction.range or lastAction.finish
+ else
+ local row = guide.rowColOf(source.finish)
+ pos = guide.positionOf(row - 1, 0)
+ end
+ callback {
+ start = pos,
+ finish = pos,
+ message = lang.script('DIAG_MISSING_RETURN'),
+ }
+ end)
+end
diff --git a/script/core/diagnostics/need-check-nil.lua b/script/core/diagnostics/need-check-nil.lua
index 98fdfd08..9c86939a 100644
--- a/script/core/diagnostics/need-check-nil.lua
+++ b/script/core/diagnostics/need-check-nil.lua
@@ -2,14 +2,18 @@ local files = require 'files'
local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
+local await = require 'await'
+---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
+ ---@async
guide.eachSourceType(state.ast, 'getlocal', function (src)
+ await.delay()
local checkNil
local nxt = src.next
if nxt then
@@ -24,11 +28,15 @@ return function (uri, callback)
if call and call.type == 'call' and call.node == src then
checkNil = true
end
+ local setIndex = src.parent
+ if setIndex and setIndex.type == 'setindex' and setIndex.index == src then
+ checkNil = true
+ end
if not checkNil then
return
end
local node = vm.compileNode(src)
- if node:hasFalsy() then
+ if node:hasFalsy() and not vm.getInfer(src):hasType(uri, 'any') then
callback {
start = src.start,
finish = src.finish,
diff --git a/script/core/diagnostics/newfield-call.lua b/script/core/diagnostics/newfield-call.lua
index 669ed2bb..bd114959 100644
--- a/script/core/diagnostics/newfield-call.lua
+++ b/script/core/diagnostics/newfield-call.lua
@@ -1,16 +1,20 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
+local await = require 'await'
+local sub = require 'core.substring'
+---@async
return function (uri, callback)
- local ast = files.getState(uri)
- if not ast then
+ local state = files.getState(uri)
+ local text = files.getText(uri)
+ if not state or not text then
return
end
- local text = files.getText(uri)
-
- guide.eachSourceType(ast.ast, 'table', function (source)
+ ---@async
+ guide.eachSourceType(state.ast, 'table', function (source)
+ await.delay()
for i = 1, #source do
local field = source[i]
if field.type ~= 'tableexp' then
@@ -33,8 +37,8 @@ return function (uri, callback)
start = call.start,
finish = call.finish,
message = lang.script('DIAG_PREFIELD_CALL'
- , text:sub(func.start, func.finish)
- , text:sub(args.start, args.finish)
+ , sub(state)(func.start + 1, func.finish)
+ , sub(state)(args.start + 1, args.finish)
)
}
end
diff --git a/script/core/diagnostics/newline-call.lua b/script/core/diagnostics/newline-call.lua
index 3f2d5ca5..2ba2ce03 100644
--- a/script/core/diagnostics/newline-call.lua
+++ b/script/core/diagnostics/newline-call.lua
@@ -1,14 +1,18 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
+local await = require 'await'
+local sub = require 'core.substring'
+---@async
return function (uri, callback)
local state = files.getState(uri)
local text = files.getText(uri)
- if not state then
+ if not state or not text then
return
end
+ ---@async
guide.eachSourceType(state.ast, 'call', function (source)
local node = source.node
local args = source.args
@@ -20,6 +24,9 @@ return function (uri, callback)
if not source.next then
return
end
+
+ await.delay()
+
local startOffset = guide.positionToOffset(state, args.start) + 1
local finishOffset = guide.positionToOffset(state, args.finish)
if text:sub(startOffset, startOffset) ~= '('
@@ -38,8 +45,8 @@ return function (uri, callback)
start = node.start,
finish = args.finish,
message = lang.script('DIAG_PREVIOUS_CALL'
- , text:sub(node.start, node.finish)
- , text:sub(args.start, args.finish)
+ , sub(state)(node.start + 1, node.finish)
+ , sub(state)(args.start + 1, args.finish)
),
}
end
diff --git a/script/core/diagnostics/no-unknown.lua b/script/core/diagnostics/no-unknown.lua
index 48aab5da..e706931a 100644
--- a/script/core/diagnostics/no-unknown.lua
+++ b/script/core/diagnostics/no-unknown.lua
@@ -2,25 +2,30 @@ local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local vm = require 'vm'
+local await = require 'await'
+local types = {
+ 'local',
+ 'setlocal',
+ 'setglobal',
+ 'getglobal',
+ 'setfield',
+ 'setindex',
+ 'tablefield',
+ 'tableindex',
+}
+
+---@async
return function (uri, callback)
local ast = files.getState(uri)
if not ast then
return
end
- guide.eachSource(ast.ast, function (source)
- if source.type ~= 'local'
- and source.type ~= 'setlocal'
- and source.type ~= 'setglobal'
- and source.type ~= 'getglobal'
- and source.type ~= 'setfield'
- and source.type ~= 'setindex'
- and source.type ~= 'tablefield'
- and source.type ~= 'tableindex' then
- return
- end
- if vm.getInfer(source):view() == 'unknown' then
+ ---@async
+ guide.eachSourceTypes(ast.ast, types, function (source)
+ await.delay()
+ if vm.getInfer(source):view(uri) == 'unknown' then
callback {
start = source.start,
finish = source.finish,
diff --git a/script/core/diagnostics/not-yieldable.lua b/script/core/diagnostics/not-yieldable.lua
index a1c84276..055025d4 100644
--- a/script/core/diagnostics/not-yieldable.lua
+++ b/script/core/diagnostics/not-yieldable.lua
@@ -11,7 +11,7 @@ local function isYieldAble(defs, i)
local arg = def.args and def.args[i]
if arg then
hasFuncDef = true
- if vm.getInfer(arg):hasType 'any'
+ if vm.getInfer(arg):hasType(guide.getUri(def), 'any')
or vm.isAsync(arg, true)
or arg.type == '...' then
return true
@@ -22,7 +22,7 @@ local function isYieldAble(defs, i)
local arg = def.args and def.args[i]
if arg then
hasFuncDef = true
- if vm.getInfer(arg.extends):hasType 'any'
+ if vm.getInfer(arg.extends):hasType(guide.getUri(def), 'any')
or vm.isAsync(arg.extends, true) then
return true
end
diff --git a/script/core/diagnostics/param-type-mismatch.lua b/script/core/diagnostics/param-type-mismatch.lua
new file mode 100644
index 00000000..6f34f579
--- /dev/null
+++ b/script/core/diagnostics/param-type-mismatch.lua
@@ -0,0 +1,72 @@
+local files = require 'files'
+local lang = require 'language'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local await = require 'await'
+
+---@async
+return function (uri, callback)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+
+ ---@param funcNode vm.node
+ ---@param i integer
+ ---@return vm.node?
+ local function getDefNode(funcNode, i)
+ local defNode = vm.createNode()
+ for f in funcNode:eachObject() do
+ if f.type == 'function'
+ or f.type == 'doc.type.function' then
+ local param = f.args and f.args[i]
+ if param then
+ defNode:merge(vm.compileNode(param))
+ if param[1] == '...' then
+ defNode:addOptional()
+ end
+ end
+ end
+ end
+ if defNode:isEmpty() then
+ return nil
+ end
+ return defNode
+ end
+
+ ---@async
+ guide.eachSourceType(state.ast, 'call', function (source)
+ if not source.args then
+ return
+ end
+ await.delay()
+ local funcNode = vm.compileNode(source.node)
+ for i, arg in ipairs(source.args) do
+ if i == 1 and source.node.type == 'getmethod' then
+ goto CONTINUE
+ end
+ local refNode = vm.compileNode(arg)
+ local defNode = getDefNode(funcNode, i)
+ if not defNode then
+ goto CONTINUE
+ end
+ if arg.type == 'getfield'
+ or arg.type == 'getindex' then
+ -- 由于无法对字段进行类型收窄,
+ -- 因此将假值移除再进行检查
+ refNode = refNode:copy():setTruthy()
+ end
+ if not vm.canCastType(uri, defNode, refNode) then
+ callback {
+ start = arg.start,
+ finish = arg.finish,
+ message = lang.script('DIAG_PARAM_TYPE_MISMATCH', {
+ def = vm.getInfer(defNode):view(uri),
+ ref = vm.getInfer(refNode):view(uri),
+ })
+ }
+ end
+ ::CONTINUE::
+ end
+ end)
+end
diff --git a/script/core/diagnostics/redefined-local.lua b/script/core/diagnostics/redefined-local.lua
index 2157ae71..1fb3ca6b 100644
--- a/script/core/diagnostics/redefined-local.lua
+++ b/script/core/diagnostics/redefined-local.lua
@@ -1,18 +1,23 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
+local await = require 'await'
+---@async
return function (uri, callback)
local ast = files.getState(uri)
if not ast then
return
end
+
+ ---@async
guide.eachSourceType(ast.ast, 'local', function (source)
local name = source[1]
if name == '_'
or name == ast.ENVMode then
return
end
+ await.delay()
local exist = guide.getLocal(source, name, source.start-1)
if exist then
callback {
diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua
index 41781df8..9898d9bd 100644
--- a/script/core/diagnostics/redundant-parameter.lua
+++ b/script/core/diagnostics/redundant-parameter.lua
@@ -2,73 +2,48 @@ local files = require 'files'
local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
+local await = require 'await'
-local function countCallArgs(source)
- local result = 0
- if not source.args then
- return 0
- end
- result = result + #source.args
- return result
-end
-
-local function countFuncArgs(source)
- if not source.args or #source.args == 0 then
- return 0
- end
- local lastArg = source.args[#source.args]
- if lastArg.type == '...'
- or (lastArg.name and lastArg.name[1] == '...') then
- return math.maxinteger
- else
- return #source.args
- end
-end
-
-local function getFuncArgs(func)
- local funcArgs
- local defs = vm.getDefs(func)
- for _, def in ipairs(defs) do
- if def.type == 'function'
- or def.type == 'doc.type.function' then
- local args = countFuncArgs(def)
- if not funcArgs or args > funcArgs then
- funcArgs = args
- end
- end
- end
- return funcArgs
-end
-
+---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
+ ---@async
guide.eachSourceType(state.ast, 'call', function (source)
- local callArgs = countCallArgs(source)
+ await.delay()
+ local callArgs = vm.countList(source.args)
if callArgs == 0 then
return
end
- local func = source.node
- local funcArgs = getFuncArgs(func)
+ local funcNode = vm.compileNode(source.node)
+ local _, funcArgs = vm.countParamsOfNode(funcNode)
- if not funcArgs then
- return
- end
-
- local delta = callArgs - funcArgs
- if delta <= 0 then
+ if callArgs <= funcArgs then
return
end
if callArgs == 1 and source.node.type == 'getmethod' then
return
end
- for i = #source.args - delta + 1, #source.args do
- local arg = source.args[i]
- if arg then
+ if funcArgs + 1 > #source.args then
+ local lastArg = source.args[#source.args]
+ if lastArg.type == 'call' and funcArgs > 0 then
+ -- 如果函数接收至少一个参数,那么调用方最后一个参数是函数调用
+ -- 导致的参数数量太多可以忽略。
+ -- 如果函数不接收任何参数,那么任何参数都是错误的。
+ return
+ end
+ callback {
+ start = lastArg.start,
+ finish = lastArg.finish,
+ message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, callArgs)
+ }
+ else
+ for i = funcArgs + 1, #source.args do
+ local arg = source.args[i]
callback {
start = arg.start,
finish = arg.finish,
diff --git a/script/core/diagnostics/redundant-return-value.lua b/script/core/diagnostics/redundant-return-value.lua
new file mode 100644
index 00000000..36432f98
--- /dev/null
+++ b/script/core/diagnostics/redundant-return-value.lua
@@ -0,0 +1,73 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local lang = require 'language'
+local await = require 'await'
+
+local function hasDocReturn(func)
+ if not func.bindDocs then
+ return false
+ end
+ for _, doc in ipairs(func.bindDocs) do
+ if doc.type == 'doc.return' then
+ return true
+ end
+ end
+ return false
+end
+
+---@async
+return function (uri, callback)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+
+ ---@async
+ guide.eachSourceType(state.ast, 'function', function (source)
+ await.delay()
+ if not hasDocReturn(source) then
+ return
+ end
+ local _, max = vm.countReturnsOfFunction(source)
+ local returns = source.returns
+ if not returns then
+ return
+ end
+ for _, ret in ipairs(returns) do
+ local rmin, rmax = vm.countList(ret)
+ if rmin > max then
+ for i = max + 1, #ret - 1 do
+ callback {
+ start = ret[i].start,
+ finish = ret[i].finish,
+ message = lang.script('DIAG_REDUNDANT_RETURN_VALUE', {
+ max = max,
+ rmax = i,
+ }),
+ }
+ end
+ if #ret == rmax then
+ callback {
+ start = ret[#ret].start,
+ finish = ret[#ret].finish,
+ message = lang.script('DIAG_REDUNDANT_RETURN_VALUE', {
+ max = max,
+ rmax = rmax,
+ }),
+ }
+ else
+ callback {
+ start = ret[#ret].start,
+ finish = ret[#ret].finish,
+ message = lang.script('DIAG_REDUNDANT_RETURN_VALUE_RANGE', {
+ max = max,
+ rmin = #ret,
+ rmax = rmax,
+ }),
+ }
+ end
+ end
+ end
+ end)
+end
diff --git a/script/core/diagnostics/return-type-mismatch.lua b/script/core/diagnostics/return-type-mismatch.lua
new file mode 100644
index 00000000..cce4aad8
--- /dev/null
+++ b/script/core/diagnostics/return-type-mismatch.lua
@@ -0,0 +1,76 @@
+local files = require 'files'
+local lang = require 'language'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local await = require 'await'
+
+---@param func parser.object
+---@return vm.node[]?
+local function getDocReturns(func)
+ if not func.bindDocs then
+ return nil
+ end
+ local returns = {}
+ for _, doc in ipairs(func.bindDocs) do
+ if doc.type == 'doc.return' then
+ for _, ret in ipairs(doc.returns) do
+ returns[ret.returnIndex] = vm.compileNode(ret)
+ end
+ end
+ end
+ if #returns == 0 then
+ return nil
+ end
+ return returns
+end
+---@async
+return function (uri, callback)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+
+ ---@param docReturns vm.node[]
+ ---@param rets parser.object
+ local function checkReturn(docReturns, rets)
+ for i, docRet in ipairs(docReturns) do
+ local retNode, exp = vm.selectNode(rets, i)
+ if not exp then
+ break
+ end
+ if retNode:hasName 'nil' then
+ if exp.type == 'getfield'
+ or exp.type == 'getindex' then
+ retNode = retNode:copy():removeOptional()
+ end
+ end
+ if not vm.canCastType(uri, docRet, retNode) then
+ callback {
+ start = exp.start,
+ finish = exp.finish,
+ message = lang.script('DIAG_RETURN_TYPE_MISMATCH', {
+ def = vm.getInfer(docRet):view(uri),
+ ref = vm.getInfer(retNode):view(uri),
+ index = i,
+ }),
+ }
+ end
+ end
+ end
+
+ ---@async
+ guide.eachSourceType(state.ast, 'function', function (source)
+ if not source.returns then
+ return
+ end
+ await.delay()
+ local docReturns = getDocReturns(source)
+ if not docReturns then
+ return
+ end
+ for _, ret in ipairs(source.returns) do
+ checkReturn(docReturns, ret)
+ await.delay()
+ end
+ end)
+end
diff --git a/script/core/diagnostics/spell-check.lua b/script/core/diagnostics/spell-check.lua
new file mode 100644
index 00000000..7369a235
--- /dev/null
+++ b/script/core/diagnostics/spell-check.lua
@@ -0,0 +1,34 @@
+local files = require 'files'
+local converter = require 'proto.converter'
+local log = require 'log'
+local spell = require 'provider.spell'
+
+
+---@async
+return function(uri, callback)
+ local text = files.getOriginText(uri)
+ if not text then
+ return
+ end
+
+ local status, diagnosticInfos = spell.spellCheck(uri, text)
+
+ if not status then
+ if diagnosticInfos ~= nil then
+ log.error(diagnosticInfos)
+ end
+
+ return
+ end
+
+ if diagnosticInfos then
+ for _, diagnosticInfo in ipairs(diagnosticInfos) do
+ callback {
+ start = converter.unpackPosition(uri, diagnosticInfo.range.start),
+ finish = converter.unpackPosition(uri, diagnosticInfo.range["end"]),
+ message = diagnosticInfo.message,
+ data = diagnosticInfo.data
+ }
+ end
+ end
+end
diff --git a/script/core/diagnostics/trailing-space.lua b/script/core/diagnostics/trailing-space.lua
index cc51cf77..2e0398b2 100644
--- a/script/core/diagnostics/trailing-space.lua
+++ b/script/core/diagnostics/trailing-space.lua
@@ -1,25 +1,18 @@
local files = require 'files'
local lang = require 'language'
local guide = require 'parser.guide'
+local await = require 'await'
-local function isInString(ast, offset)
- local result = false
- guide.eachSourceType(ast, 'string', function (source)
- if offset >= source.start and offset <= source.finish then
- result = true
- end
- end)
- return result
-end
-
+---@async
return function (uri, callback)
local state = files.getState(uri)
- if not state then
+ local text = files.getText(uri)
+ if not state or not text then
return
end
- local text = files.getText(uri)
local lines = state.lines
for i = 0, #lines do
+ await.delay()
local startOffset = lines[i]
local finishOffset = text:find('[\r\n]', startOffset) or (#text + 1)
local lastOffset = finishOffset - 1
@@ -28,7 +21,8 @@ return function (uri, callback)
goto NEXT_LINE
end
local lastPos = guide.offsetToPosition(state, lastOffset)
- if isInString(state.ast, lastPos) then
+ if guide.isInString(state.ast, lastPos)
+ or guide.isInComment(state.ast, lastPos) then
goto NEXT_LINE
end
local firstOffset = startOffset
diff --git a/script/core/diagnostics/type-check.lua b/script/core/diagnostics/type-check.lua
deleted file mode 100644
index cc2b3228..00000000
--- a/script/core/diagnostics/type-check.lua
+++ /dev/null
@@ -1,3 +0,0 @@
----@async
-return function(uri, callback)
-end
diff --git a/script/core/diagnostics/unbalanced-assignments.lua b/script/core/diagnostics/unbalanced-assignments.lua
index df71f0c9..c21ca993 100644
--- a/script/core/diagnostics/unbalanced-assignments.lua
+++ b/script/core/diagnostics/unbalanced-assignments.lua
@@ -2,7 +2,17 @@ local files = require 'files'
local define = require 'proto.define'
local lang = require 'language'
local guide = require 'parser.guide'
+local await = require 'await'
+local types = {
+ 'local',
+ 'setlocal',
+ 'setglobal',
+ 'setfield',
+ 'setindex' ,
+}
+
+---@async
return function (uri, callback, code)
local ast = files.getState(uri)
if not ast then
@@ -31,13 +41,9 @@ return function (uri, callback, code)
end
end
- guide.eachSource(ast.ast, function (source)
- if source.type == 'local'
- or source.type == 'setlocal'
- or source.type == 'setglobal'
- or source.type == 'setfield'
- or source.type == 'setindex' then
- checkSet(source)
- end
+ ---@async
+ guide.eachSourceTypes(ast.ast, types, function (source)
+ await.delay()
+ checkSet(source)
end)
end
diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua
index 69edb380..bacd4288 100644
--- a/script/core/diagnostics/undefined-doc-name.lua
+++ b/script/core/diagnostics/undefined-doc-name.lua
@@ -32,7 +32,7 @@ return function (uri, callback)
return
end
local name = source[1]
- if name == '...' then
+ if name == '...' or name == '_' then
return
end
if #vm.getDocSets(uri, name) > 0
diff --git a/script/core/diagnostics/undefined-doc-param.lua b/script/core/diagnostics/undefined-doc-param.lua
index 98919284..7a60a74f 100644
--- a/script/core/diagnostics/undefined-doc-param.lua
+++ b/script/core/diagnostics/undefined-doc-param.lua
@@ -1,21 +1,6 @@
local files = require 'files'
local lang = require 'language'
-local function hasParamName(func, name)
- if not func.args then
- return false
- end
- for _, arg in ipairs(func.args) do
- if arg[1] == name then
- return true
- end
- if arg.type == '...' and name == '...' then
- return true
- end
- end
- return false
-end
-
return function (uri, callback)
local state = files.getState(uri)
if not state then
@@ -27,26 +12,13 @@ return function (uri, callback)
end
for _, doc in ipairs(state.ast.docs) do
- if doc.type ~= 'doc.param' then
- goto CONTINUE
- end
- local binds = doc.bindSources
- if not binds then
- goto CONTINUE
- end
- local param = doc.param
- local name = param[1]
- for _, source in ipairs(binds) do
- if source.type == 'function' then
- if not hasParamName(source, name) then
- callback {
- start = param.start,
- finish = param.finish,
- message = lang.script('DIAG_UNDEFINED_DOC_PARAM', name)
- }
- end
- end
+ if doc.type == 'doc.param'
+ and not doc.bindSource then
+ callback {
+ start = doc.param.start,
+ finish = doc.param.finish,
+ message = lang.script('DIAG_UNDEFINED_DOC_PARAM', doc.param[1])
+ }
end
- ::CONTINUE::
end
end
diff --git a/script/core/diagnostics/undefined-env-child.lua b/script/core/diagnostics/undefined-env-child.lua
index 2f559697..1dff575b 100644
--- a/script/core/diagnostics/undefined-env-child.lua
+++ b/script/core/diagnostics/undefined-env-child.lua
@@ -3,20 +3,40 @@ local guide = require 'parser.guide'
local lang = require 'language'
local vm = require "vm.vm"
+---@param source parser.object
+---@return boolean
+local function isBindDoc(source)
+ if not source.bindDocs then
+ return false
+ end
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.type'
+ or doc.type == 'doc.class' then
+ return true
+ end
+ end
+ return false
+end
+
return function (uri, callback)
- local ast = files.getState(uri)
- if not ast then
+ local state = files.getState(uri)
+ if not state then
return
end
- guide.eachSourceType(ast.ast, 'getglobal', function (source)
- -- 单独验证自己是否在重载过的 _ENV 中有定义
+
+ guide.eachSourceType(state.ast, 'getglobal', function (source)
if source.node.tag == '_ENV' then
return
end
- local defs = vm.getDefs(source)
- if #defs > 0 then
+
+ if not isBindDoc(source.node) then
return
end
+
+ if #vm.getDefs(source) > 0 then
+ return
+ end
+
local key = source[1]
callback {
start = source.start,
diff --git a/script/core/diagnostics/undefined-field.lua b/script/core/diagnostics/undefined-field.lua
index 41fcda48..a83241f5 100644
--- a/script/core/diagnostics/undefined-field.lua
+++ b/script/core/diagnostics/undefined-field.lua
@@ -34,11 +34,11 @@ return function (uri, callback)
local node = src.node
if node then
local ok
- for view in vm.getInfer(node):eachView() do
- if not skipCheckClass[view] then
- ok = true
- break
+ for view in vm.getInfer(node):eachView(uri) do
+ if skipCheckClass[view] then
+ return
end
+ ok = true
end
if not ok then
return
diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua
index bd0aae69..179c9204 100644
--- a/script/core/diagnostics/undefined-global.lua
+++ b/script/core/diagnostics/undefined-global.lua
@@ -4,6 +4,7 @@ local lang = require 'language'
local config = require 'config'
local guide = require 'parser.guide'
local await = require 'await'
+local util = require 'utility'
local requireLike = {
['include'] = true,
@@ -14,17 +15,17 @@ local requireLike = {
---@async
return function (uri, callback)
- local ast = files.getState(uri)
- if not ast then
+ local state = files.getState(uri)
+ if not state then
return
end
- local dglobals = config.get(uri, 'Lua.diagnostics.globals')
+ local dglobals = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals'))
local rspecial = config.get(uri, 'Lua.runtime.special')
local cache = {}
-- 遍历全局变量,检查所有没有 set 模式的全局变量
- guide.eachSourceType(ast.ast, 'getglobal', function (src) ---@async
+ guide.eachSourceType(state.ast, 'getglobal', function (src) ---@async
local key = src[1]
if not key then
return
@@ -40,6 +41,7 @@ return function (uri, callback)
return
end
if cache[key] == nil then
+ await.delay()
cache[key] = vm.hasGlobalSets(uri, 'variable', key)
end
if cache[key] then
diff --git a/script/core/diagnostics/unknown-cast-variable.lua b/script/core/diagnostics/unknown-cast-variable.lua
new file mode 100644
index 00000000..3f082a50
--- /dev/null
+++ b/script/core/diagnostics/unknown-cast-variable.lua
@@ -0,0 +1,32 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local vm = require 'vm'
+local await = require 'await'
+
+---@async
+return function (uri, callback)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ for _, doc in ipairs(state.ast.docs) do
+ if doc.type == 'doc.cast' and doc.loc then
+ await.delay()
+ local defs = vm.getDefs(doc.loc)
+ local loc = defs[1]
+ if not loc then
+ callback {
+ start = doc.loc.start,
+ finish = doc.loc.finish,
+ message = lang.script('DIAG_UNKNOWN_CAST_VARIABLE', doc.loc[1])
+ }
+ end
+ end
+ end
+end
diff --git a/script/core/diagnostics/unknown-diag-code.lua b/script/core/diagnostics/unknown-diag-code.lua
index 9e492a29..07128a27 100644
--- a/script/core/diagnostics/unknown-diag-code.lua
+++ b/script/core/diagnostics/unknown-diag-code.lua
@@ -1,6 +1,6 @@
local files = require 'files'
local lang = require 'language'
-local define = require 'proto.define'
+local diag = require 'proto.diagnostic'
return function (uri, callback)
local state = files.getState(uri)
@@ -17,7 +17,7 @@ return function (uri, callback)
if doc.names then
for _, nameUnit in ipairs(doc.names) do
local code = nameUnit[1]
- if not define.DiagnosticDefaultSeverity[code] then
+ if not diag.getDiagAndErrNameMap()[code] then
callback {
start = nameUnit.start,
finish = nameUnit.finish,
diff --git a/script/core/diagnostics/unknown-operator.lua b/script/core/diagnostics/unknown-operator.lua
new file mode 100644
index 00000000..7404b5ef
--- /dev/null
+++ b/script/core/diagnostics/unknown-operator.lua
@@ -0,0 +1,36 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local vm = require 'vm'
+local await = require 'await'
+local util = require 'utility'
+
+---@async
+return function (uri, callback)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ for _, doc in ipairs(state.ast.docs) do
+ if doc.type == 'doc.operator' then
+ local op = doc.op
+ if op then
+ local opName = op[1]
+ if not vm.OP_BINARY_MAP[opName]
+ and not vm.OP_UNARY_MAP[opName]
+ and not vm.OP_OTHER_MAP[opName] then
+ callback {
+ start = doc.op.start,
+ finish = doc.op.finish,
+ message = lang.script('DIAG_UNKNOWN_OPERATOR', opName)
+ }
+ end
+ end
+ end
+ end
+end
diff --git a/script/core/diagnostics/unreachable-code.lua b/script/core/diagnostics/unreachable-code.lua
new file mode 100644
index 00000000..4f0a38b7
--- /dev/null
+++ b/script/core/diagnostics/unreachable-code.lua
@@ -0,0 +1,84 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local lang = require 'language'
+local await = require 'await'
+local define = require 'proto.define'
+
+---@param source parser.object
+---@return boolean
+local function allLiteral(source)
+ local result = true
+ guide.eachSource(source, function (src)
+ if src.type ~= 'unary'
+ and src.type ~= 'binary'
+ and not guide.isLiteral(src) then
+ result = false
+ return false
+ end
+ end)
+ return result
+end
+
+---@param block parser.object
+---@return boolean
+local function hasReturn(block)
+ if block.hasReturn or block.hasError then
+ return true
+ end
+ if block.type == 'if' then
+ local hasElse
+ for _, subBlock in ipairs(block) do
+ if not hasReturn(subBlock) then
+ return false
+ end
+ if subBlock.type == 'elseblock' then
+ hasElse = true
+ end
+ end
+ return hasElse == true
+ else
+ if block.type == 'while' then
+ if vm.testCondition(block.filter)
+ and not block.breaks
+ and allLiteral(block.filter) then
+ return true
+ end
+ end
+ for _, action in ipairs(block) do
+ if guide.isBlockType(action) then
+ if hasReturn(action) then
+ return true
+ end
+ end
+ end
+ end
+ return false
+end
+
+---@async
+return function (uri, callback)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+
+ ---@async
+ guide.eachSourceTypes(state.ast, {'main', 'function'}, function (source)
+ await.delay()
+ for i, action in ipairs(source) do
+ if guide.isBlockType(action)
+ and hasReturn(action) then
+ if i < #source then
+ callback {
+ start = source[i+1].start,
+ finish = source[#source].finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script('DIAG_UNREACHABLE_CODE'),
+ }
+ end
+ return
+ end
+ end
+ end)
+end
diff --git a/script/core/diagnostics/unused-function.lua b/script/core/diagnostics/unused-function.lua
index 813ac804..a873375f 100644
--- a/script/core/diagnostics/unused-function.lua
+++ b/script/core/diagnostics/unused-function.lua
@@ -18,7 +18,8 @@ local function isToBeClosed(source)
return false
end
----@param source parser.object
+---@param source parser.object?
+---@return boolean
local function isValidFunction(source)
if not source then
return false
@@ -55,7 +56,7 @@ local function collect(ast, white, roots, links)
for _, ref in ipairs(loc.ref or {}) do
if ref.type == 'getlocal' then
local func = guide.getParentFunction(ref)
- if not isValidFunction(func) or roots[func] then
+ if not func or not isValidFunction(func) or roots[func] then
roots[src] = true
return
end
diff --git a/script/core/diagnostics/unused-local.lua b/script/core/diagnostics/unused-local.lua
index d12ceb2b..8f2ee217 100644
--- a/script/core/diagnostics/unused-local.lua
+++ b/script/core/diagnostics/unused-local.lua
@@ -3,6 +3,8 @@ local guide = require 'parser.guide'
local define = require 'proto.define'
local lang = require 'language'
local vm = require 'vm.vm'
+local config = require 'config.config'
+local glob = require 'glob'
local function hasGet(loc)
if not loc.ref then
@@ -63,18 +65,24 @@ local function isDocClass(source)
return false
end
-local function isDocParam(source)
- if not source.bindDocs then
+---@param func parser.object
+---@return boolean
+local function isEmptyFunction(func)
+ if #func > 0 then
return false
end
- for _, doc in ipairs(source.bindDocs) do
- if doc.type == 'doc.param' then
- if doc.param[1] == source[1] then
- return true
- end
- end
+ local startRow = guide.rowColOf(func.start)
+ local finishRow = guide.rowColOf(func.finish)
+ return finishRow - startRow <= 1
+end
+
+---@param source parser.object
+local function isDeclareFunctionParam(source)
+ if source.parent.type ~= 'funcargs' then
+ return false
end
- return false
+ local func = source.parent.parent
+ return isEmptyFunction(func)
end
return function (uri, callback)
@@ -82,19 +90,24 @@ return function (uri, callback)
if not ast then
return
end
+ local ignorePatterns = config.get(uri, 'Lua.diagnostics.unusedLocalExclude')
+ local ignore = glob.glob(ignorePatterns)
guide.eachSourceType(ast.ast, 'local', function (source)
local name = source[1]
if name == '_'
or name == ast.ENVMode then
return
end
+ if ignore(name) then
+ return
+ end
if isToBeClosed(source) then
return
end
if isDocClass(source) then
return
end
- if vm.isMetaFile(uri) and isDocParam(source) then
+ if isDeclareFunctionParam(source) then
return
end
local data = hasGet(source)
diff --git a/script/core/diagnostics/unused-vararg.lua b/script/core/diagnostics/unused-vararg.lua
index ce033cf3..08f12c4d 100644
--- a/script/core/diagnostics/unused-vararg.lua
+++ b/script/core/diagnostics/unused-vararg.lua
@@ -15,6 +15,9 @@ return function (uri, callback)
end
guide.eachSourceType(ast.ast, 'function', function (source)
+ if #source == 0 then
+ return
+ end
local args = source.args
if not args then
return
diff --git a/script/core/find-source.lua b/script/core/find-source.lua
index 26a411e5..99013b31 100644
--- a/script/core/find-source.lua
+++ b/script/core/find-source.lua
@@ -21,7 +21,7 @@ return function (ast, position, accept)
end
end
local start, finish = guide.getStartFinish(source)
- if finish - start < len and accept[source.type] then
+ if finish - start <= len and accept[source.type] then
result = source
len = finish - start
end
diff --git a/script/core/folding.lua b/script/core/folding.lua
index 4f93aed9..7f59636e 100644
--- a/script/core/folding.lua
+++ b/script/core/folding.lua
@@ -66,7 +66,8 @@ local care = {
['repeat'] = function (source, text, results)
local start = source.start
local finish = source.keyword[#source.keyword]
- if text:sub(finish - #'until' + 1, finish) ~= 'until' then
+ -- must end with 'until'
+ if #source.keyword ~= 4 then
return
end
local folding = {
@@ -143,6 +144,15 @@ local care = {
}
results[#results+1] = folding
end,
+ ['doc.alias'] = function (source, text, results)
+ local folding = {
+ start = source.start,
+ finish = source.bindGroup[#source.bindGroup].finish,
+ kind = 'comment',
+ hideLastLine = true,
+ }
+ results[#results+1] = folding
+ end,
}
---@async
diff --git a/script/core/formatting.lua b/script/core/formatting.lua
index b52854a4..fb5ca9c7 100644
--- a/script/core/formatting.lua
+++ b/script/core/formatting.lua
@@ -4,7 +4,10 @@ local log = require("log")
return function(uri, options)
local text = files.getOriginText(uri)
- local ast = files.getState(uri)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
local status, formattedText = codeFormat.format(uri, text, options)
if not status then
@@ -17,8 +20,8 @@ return function(uri, options)
return {
{
- start = ast.ast.start,
- finish = ast.ast.finish,
+ start = state.ast.start,
+ finish = state.ast.finish,
text = formattedText,
}
}
diff --git a/script/core/hint.lua b/script/core/hint.lua
index f97cdcec..767e531e 100644
--- a/script/core/hint.lua
+++ b/script/core/hint.lua
@@ -5,6 +5,7 @@ local guide = require 'parser.guide'
local await = require 'await'
local define = require 'proto.define'
local lang = require 'language'
+local substr = require 'core.substring'
---@async
local function typeHint(uri, results, start, finish)
@@ -38,7 +39,7 @@ local function typeHint(uri, results, start, finish)
end
end
await.delay()
- local view = vm.getInfer(source):view()
+ local view = vm.getInfer(source):view(uri)
if view == 'any'
or view == 'unknown'
or view == 'nil' then
@@ -189,24 +190,44 @@ local function arrayIndex(uri, results, start, finish)
end
---@async
- guide.eachSourceBetween(state.ast, start, finish, function (source)
- if source.type ~= 'tableexp' then
+ guide.eachSourceType(state.ast, 'table', function (source)
+ if source.finish < start or source.start > finish then
return
end
await.delay()
if option == 'Auto' then
- if not isMixedOrLargeTable(source.parent) then
+ if not isMixedOrLargeTable(source) then
return
end
end
- results[#results+1] = {
- text = ('[%d]'):format(source.tindex),
- offset = source.start,
- kind = define.InlayHintKind.Other,
- where = 'left',
- source = source.parent,
- }
+ local list = {}
+ local max = 0
+ for _, field in ipairs(source) do
+ if field.type == 'tableexp'
+ and field.start < finish
+ and field.finish > start then
+ list[#list+1] = field
+ if field.tindex > max then
+ max = field.tindex
+ end
+ end
+ end
+
+ if #list > 0 then
+ local length = #tostring(max)
+ local fmt = '[%0' .. length .. 'd]'
+ for _, field in ipairs(list) do
+ results[#results+1] = {
+ text = fmt:format(field.tindex),
+ offset = field.start,
+ kind = define.InlayHintKind.Other,
+ where = 'left',
+ source = field.parent,
+ }
+ end
+ end
end)
+
end
---@async
@@ -238,6 +259,72 @@ local function awaitHint(uri, results, start, finish)
end)
end
+local blockTypes = {
+ 'main',
+ 'function',
+ 'for',
+ 'loop',
+ 'in',
+ 'do',
+ 'repeat',
+ 'while',
+ 'ifblock',
+ 'elseifblock',
+ 'elseblock',
+}
+
+---@async
+local function semicolonHint(uri, results, start, finish)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+ local mode = config.get(uri, 'Lua.hint.semicolon')
+ if mode == 'Disable' then
+ return
+ end
+ local subber = substr(state)
+ ---@async
+ guide.eachSourceTypes(state.ast, blockTypes, function (src)
+ await.delay()
+ for i = 1, #src - 1 do
+ local current = src[i]
+ local next = src[i+1]
+ local left = current.finish
+ local right = next.start
+ local text = subber(left, right)
+ if mode == 'All' then
+ if not text:find '[,;]' then
+ results[#results+1] = {
+ text = ';',
+ offset = left,
+ kind = define.InlayHintKind.Other,
+ where = 'right',
+ }
+ end
+ elseif mode == 'SameLine' then
+ if not text:find '[,;\r\n]' then
+ results[#results+1] = {
+ text = ';',
+ offset = left,
+ kind = define.InlayHintKind.Other,
+ where = 'right',
+ }
+ end
+ end
+ end
+ if mode == 'All' then
+ local last = src[#src]
+ results[#results+1] = {
+ text = ';',
+ offset = last.range or last.finish,
+ kind = define.InlayHintKind.Other,
+ where = 'right',
+ }
+ end
+ end)
+end
+
---@async
return function (uri, start, finish)
local results = {}
@@ -245,5 +332,6 @@ return function (uri, start, finish)
paramName(uri, results, start, finish)
awaitHint(uri, results, start, finish)
arrayIndex(uri, results, start, finish)
+ semicolonHint(uri, results, start, finish)
return results
end
diff --git a/script/core/hover/args.lua b/script/core/hover/args.lua
index c485d9b9..bb4d4297 100644
--- a/script/core/hover/args.lua
+++ b/script/core/hover/args.lua
@@ -9,7 +9,7 @@ local function asFunction(source)
methodDef = true
end
if methodDef then
- args[#args+1] = ('self: %s'):format(vm.getInfer(parent.node):view 'any')
+ args[#args+1] = ('self: %s'):format(vm.getInfer(parent.node):view(guide.getUri(source), 'any'))
end
if source.args then
for i = 1, #source.args do
@@ -29,15 +29,15 @@ local function asFunction(source)
args[#args+1] = ('%s%s: %s'):format(
name,
optional and '?' or '',
- vm.getInfer(argNode):view('any', guide.getUri(source))
+ vm.getInfer(argNode):view(guide.getUri(source), 'any')
)
elseif arg.type == '...' then
- args[#args+1] = ('%s: %s'):format(
+ args[#args+1] = ('%s%s'):format(
'...',
- vm.getInfer(arg):view 'any'
+ vm.getInfer(arg):view(guide.getUri(source), 'any')
)
else
- args[#args+1] = ('%s'):format(vm.getInfer(arg):view 'any')
+ args[#args+1] = ('%s'):format(vm.getInfer(arg):view(guide.getUri(source), 'any'))
end
::CONTINUE::
end
@@ -46,17 +46,17 @@ local function asFunction(source)
end
local function asDocFunction(source)
+ local args = {}
if not source.args then
- return ''
+ return args
end
- local args = {}
for i = 1, #source.args do
local arg = source.args[i]
local name = arg.name[1]
args[i] = ('%s%s: %s'):format(
name,
arg.optional and '?' or '',
- arg.extends and vm.getInfer(arg.extends):view 'any' or 'any'
+ arg.extends and vm.getInfer(arg.extends):view(guide.getUri(source), 'any') or 'any'
)
end
return args
diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua
index e9267c0f..e11dd6c8 100644
--- a/script/core/hover/description.lua
+++ b/script/core/hover/description.lua
@@ -6,11 +6,12 @@ local lang = require 'language'
local util = require 'utility'
local guide = require 'parser.guide'
local rpath = require 'workspace.require-path'
+local furi = require 'file-uri'
local function collectRequire(mode, literal, uri)
local result, searchers
if mode == 'require' then
- result, searchers = rpath.findUrisByRequirePath(uri, literal)
+ result, searchers = rpath.findUrisByRequireName(uri, literal)
elseif mode == 'dofile'
or mode == 'loadfile' then
result = ws.findUrisByFilePath(literal)
@@ -82,7 +83,53 @@ local function asString(source)
or asStringView(source, literal)
end
-local function getBindComment(source, docGroup, base)
+---@param comment string
+---@param suri uri
+---@return string?
+local function normalizeComment(comment, suri)
+ if not comment then
+ return nil
+ end
+ if comment:sub(1, 1) == '-' then
+ comment = comment:sub(2)
+ end
+ if comment:sub(1, 1) == '@' then
+ return nil
+ end
+ comment = comment:gsub('(%[.-%]%()(.-)(%))', function (left, path, right)
+ local scheme = furi.split(path)
+ if scheme
+ -- strange way to check `C:/xxx.lua`
+ and #scheme > 1 then
+ return
+ end
+ local absPath = ws.getAbsolutePath(suri:gsub('/[^/]+$', ''), path)
+ if not absPath then
+ return
+ end
+ local uri = furi.encode(absPath)
+ return left .. uri .. right
+ end)
+ return comment
+end
+
+local function getBindComment(source)
+ local uri = guide.getUri(source)
+ local lines = {}
+ for _, docComment in ipairs(source.bindComments) do
+ lines[#lines+1] = normalizeComment(docComment.comment.text, uri)
+ end
+ if not lines or #lines == 0 then
+ return nil
+ end
+ return table.concat(lines, '\n')
+end
+
+local function lookUpDocComments(source)
+ local docGroup = source.bindDocs
+ if not docGroup then
+ return
+ end
if source.type == 'setlocal'
or source.type == 'getlocal' then
source = source.node
@@ -90,34 +137,23 @@ local function getBindComment(source, docGroup, base)
if source.parent.type == 'funcargs' then
return
end
- local continue
- local lines
+ local uri = guide.getUri(source)
+ local lines = {}
for _, doc in ipairs(docGroup) do
if doc.type == 'doc.comment' then
- if not continue then
- continue = true
- lines = {}
+ lines[#lines+1] = normalizeComment(doc.comment.text, uri)
+ elseif doc.type == 'doc.type' then
+ if doc.comment then
+ lines[#lines+1] = normalizeComment(doc.comment.text, uri)
end
- if doc.comment.text:sub(1, 1) == '-' then
- lines[#lines+1] = doc.comment.text:sub(2)
- else
- lines[#lines+1] = doc.comment.text
- end
- elseif doc == base then
- break
- else
- continue = false
- if doc.type == 'doc.field'
- or doc.type == 'doc.class' then
- lines = nil
+ elseif doc.type == 'doc.class' then
+ for _, docComment in ipairs(doc.bindComments) do
+ lines[#lines+1] = normalizeComment(docComment.comment.text, uri)
end
end
end
if source.comment then
- if not lines then
- lines = {}
- end
- lines[#lines+1] = source.comment.text
+ lines[#lines+1] = normalizeComment(source.comment.text, uri)
end
if not lines or #lines == 0 then
return nil
@@ -128,8 +164,9 @@ end
local function tryDocClassComment(source)
for _, def in ipairs(vm.getDefs(source)) do
if def.type == 'doc.class'
- or def.type == 'doc.alias' then
- local comment = getBindComment(def, def.bindGroup, def)
+ or def.type == 'doc.alias'
+ or def.type == 'doc.enum' then
+ local comment = getBindComment(def)
if comment then
return comment
end
@@ -144,7 +181,7 @@ local function tryDocModule(source)
return collectRequire('require', source.module, guide.getUri(source))
end
-local function buildEnumChunk(docType, name)
+local function buildEnumChunk(docType, name, uri)
if not docType then
return nil
end
@@ -152,10 +189,11 @@ local function buildEnumChunk(docType, name)
local types = {}
local lines = {}
for _, tp in ipairs(vm.getDefs(docType)) do
- types[#types+1] = vm.getInfer(tp):view()
+ types[#types+1] = vm.getInfer(tp):view(guide.getUri(docType))
if tp.type == 'doc.type.string'
or tp.type == 'doc.type.integer'
- or tp.type == 'doc.type.boolean' then
+ or tp.type == 'doc.type.boolean'
+ or tp.type == 'doc.type.code' then
enums[#enums+1] = tp
end
local comment = tryDocClassComment(tp)
@@ -174,7 +212,7 @@ local function buildEnumChunk(docType, name)
(enum.default and '->')
or (enum.additional and '+>')
or ' |',
- vm.viewObject(enum)
+ vm.viewObject(enum, uri)
)
if enum.comment then
local first = true
@@ -198,26 +236,33 @@ local function getBindEnums(source, docGroup)
return
end
+ local uri = guide.getUri(source)
local mark = {}
local chunks = {}
local returnIndex = 0
for _, doc in ipairs(docGroup) do
if doc.type == 'doc.param' then
local name = doc.param[1]
+ if name == '...' then
+ name = '...(param)'
+ end
if mark[name] then
goto CONTINUE
end
mark[name] = true
- chunks[#chunks+1] = buildEnumChunk(doc.extends, name)
+ chunks[#chunks+1] = buildEnumChunk(doc.extends, name, uri)
elseif doc.type == 'doc.return' then
for _, rtn in ipairs(doc.returns) do
returnIndex = returnIndex + 1
local name = rtn.name and rtn.name[1] or ('return #%d'):format(returnIndex)
+ if name == '...' then
+ name = '...(return)'
+ end
if mark[name] then
goto CONTINUE
end
mark[name] = true
- chunks[#chunks+1] = buildEnumChunk(rtn, name)
+ chunks[#chunks+1] = buildEnumChunk(rtn, name, uri)
end
end
::CONTINUE::
@@ -228,37 +273,38 @@ local function getBindEnums(source, docGroup)
return table.concat(chunks, '\n\n')
end
-local function tryDocFieldUpComment(source)
- if source.type ~= 'doc.field.name' then
+local function tryDocFieldComment(source)
+ if source.type ~= 'doc.field' then
return
end
- local docField = source.parent
- if not docField.bindGroup then
- return
+ if source.comment then
+ return normalizeComment(source.comment.text, guide.getUri(source))
+ end
+ if source.bindGroup then
+ return getBindComment(source)
end
- local comment = getBindComment(docField, docField.bindGroup, docField)
- return comment
end
local function getFunctionComment(source)
local docGroup = source.bindDocs
+ if not docGroup then
+ return
+ end
local hasReturnComment = false
- for _, doc in ipairs(docGroup) do
+ for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.return' and doc.comment then
hasReturnComment = true
break
end
end
+ local uri = guide.getUri(source)
local md = markdown()
for _, doc in ipairs(docGroup) do
if doc.type == 'doc.comment' then
- if doc.comment.text:sub(1, 1) == '-' then
- md:add('md', doc.comment.text:sub(2))
- else
- md:add('md', doc.comment.text)
- end
+ local comment = normalizeComment(doc.comment.text, uri)
+ md:add('md', comment)
elseif doc.type == 'doc.param' then
if doc.comment then
md:add('md', ('@*param* `%s` — %s'):format(
@@ -295,18 +341,36 @@ local function getFunctionComment(source)
local enums = getBindEnums(source, docGroup)
md:add('lua', enums)
- return md
+
+ local comment = md:string()
+ if comment == '' then
+ return nil
+ end
+ return comment
end
local function tryDocComment(source)
- if not source.bindDocs then
- return
+ local md = markdown()
+ if source.type == 'function' then
+ local comment = getFunctionComment(source)
+ md:add('md', comment)
+ source = source.parent
end
- if source.type ~= 'function' then
- local comment = getBindComment(source, source.bindDocs)
- return comment
+ local comment = lookUpDocComments(source)
+ md:add('md', comment)
+ if source.type == 'doc.alias' then
+ local enums = buildEnumChunk(source, source.alias[1], guide.getUri(source))
+ md:add('lua', enums)
end
- return getFunctionComment(source)
+ if source.type == 'doc.enum' then
+ local enums = buildEnumChunk(source, source.enum[1], guide.getUri(source))
+ md:add('lua', enums)
+ end
+ local result = md:string()
+ if result == '' then
+ return nil
+ end
+ return result
end
local function tryDocOverloadToComment(source)
@@ -315,14 +379,12 @@ local function tryDocOverloadToComment(source)
end
local doc = source.parent
if doc.type ~= 'doc.overload'
- or not doc.bindSources then
+ or not doc.bindSource then
return
end
- for _, src in ipairs(doc.bindSources) do
- local md = tryDocComment(src)
- if md then
- return md
- end
+ local md = tryDocComment(doc.bindSource)
+ if md then
+ return md
end
end
@@ -350,6 +412,45 @@ local function tyrDocParamComment(source)
end
end
+---@param source parser.object
+local function tryDocEnum(source)
+ if source.type ~= 'doc.enum' then
+ return
+ end
+ local tbl = source.bindSource
+ if not tbl then
+ return
+ end
+ local md = markdown()
+ md:add('lua', '{')
+ for _, field in ipairs(tbl) do
+ if field.type == 'tablefield'
+ or field.type == 'tableindex' then
+ if not field.value then
+ goto CONTINUE
+ end
+ local key = guide.getKeyName(field)
+ if not key then
+ goto CONTINUE
+ end
+ if field.value.type == 'integer'
+ or field.value.type == 'string' then
+ md:add('lua', (' %s: %s = %s,'):format(key, field.value.type, field.value[1]))
+ end
+ if field.value.type == 'binary'
+ or field.value.type == 'unary' then
+ local number = vm.getNumber(field.value)
+ if number then
+ md:add('lua', (' %s: %s = %s,'):format(key, math.tointeger(number) and 'integer' or 'number', number))
+ end
+ end
+ ::CONTINUE::
+ end
+ end
+ md:add('lua', '}')
+ return md:string()
+end
+
return function (source)
if source.type == 'string' then
return asString(source)
@@ -358,9 +459,10 @@ return function (source)
source = source.parent
end
return tryDocOverloadToComment(source)
- or tryDocFieldUpComment(source)
+ or tryDocFieldComment(source)
or tyrDocParamComment(source)
or tryDocComment(source)
or tryDocClassComment(source)
or tryDocModule(source)
+ or tryDocEnum(source)
end
diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua
index 7231944a..5a65cbce 100644
--- a/script/core/hover/init.lua
+++ b/script/core/hover/init.lua
@@ -39,7 +39,7 @@ local function getHover(source)
end
local oop
- if vm.getInfer(source):view() == 'function' then
+ if vm.getInfer(source):view(guide.getUri(source)) == 'function' then
local defs = vm.getDefs(source)
-- make sure `function` is before `doc.type.function`
local orders = {}
@@ -92,19 +92,21 @@ local function getHover(source)
end
local accept = {
- ['local'] = true,
- ['setlocal'] = true,
- ['getlocal'] = true,
- ['setglobal'] = true,
- ['getglobal'] = true,
- ['field'] = true,
- ['method'] = true,
- ['string'] = true,
- ['number'] = true,
- ['integer'] = true,
- ['doc.type.name'] = true,
- ['function'] = true,
- ['doc.module'] = true,
+ ['local'] = true,
+ ['setlocal'] = true,
+ ['getlocal'] = true,
+ ['setglobal'] = true,
+ ['getglobal'] = true,
+ ['field'] = true,
+ ['method'] = true,
+ ['string'] = true,
+ ['number'] = true,
+ ['integer'] = true,
+ ['doc.type.name'] = true,
+ ['doc.class.name'] = true,
+ ['doc.enum.name'] = true,
+ ['function'] = true,
+ ['doc.module'] = true,
}
---@async
diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua
index 2bbfe806..5c502ec1 100644
--- a/script/core/hover/label.lua
+++ b/script/core/hover/label.lua
@@ -33,7 +33,10 @@ local function asDocTypeName(source)
return '(class) ' .. doc.class[1]
end
if doc.type == 'doc.alias' then
- return '(alias) ' .. doc.alias[1] .. ' ' .. lang.script('HOVER_EXTENDS', vm.getInfer(doc.extends):view())
+ return '(alias) ' .. doc.alias[1] .. ' ' .. lang.script('HOVER_EXTENDS', vm.getInfer(doc.extends):view(guide.getUri(source)))
+ end
+ if doc.type == 'doc.enum' then
+ return '(enum) ' .. doc.enum[1]
end
end
end
@@ -42,7 +45,7 @@ end
local function asValue(source, title)
local name = buildName(source, false) or ''
local ifr = vm.getInfer(source)
- local type = ifr:view()
+ local type = ifr:view(guide.getUri(source))
local literal = ifr:viewLiterals()
local cont = buildTable(source)
local pack = {}
@@ -55,10 +58,11 @@ local function asValue(source, title)
and ( type == 'table'
or type == 'any'
or type == 'unknown'
- or type == 'nil') then
- type = nil
+ or type == 'nil'
+ or type:sub(1, 1) == '{') then
+ else
+ pack[#pack+1] = type
end
- pack[#pack+1] = type
if literal then
pack[#pack+1] = '='
pack[#pack+1] = literal
@@ -139,7 +143,7 @@ local function asDocFieldName(source)
break
end
end
- local view = vm.getInfer(source.extends):view()
+ local view = vm.getInfer(source.extends):view(guide.getUri(source))
if not class then
return ('(field) ?.%s: %s'):format(name, view)
end
@@ -212,7 +216,8 @@ return function (source, oop)
elseif source.type == 'number'
or source.type == 'integer' then
return asNumber(source)
- elseif source.type == 'doc.type.name' then
+ elseif source.type == 'doc.type.name'
+ or source.type == 'doc.enum.name' then
return asDocTypeName(source)
elseif source.type == 'doc.field' then
return asDocFieldName(source)
diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua
index f8473638..3fabfb89 100644
--- a/script/core/hover/name.lua
+++ b/script/core/hover/name.lua
@@ -20,6 +20,9 @@ local function asField(source, oop)
local class
if source.node.type ~= 'getglobal' then
class = vm.getInfer(source.node):viewClass()
+ if class == 'any' or class == 'unknown' then
+ class = nil
+ end
end
local node = class
or buildName(source.node, false)
@@ -47,14 +50,12 @@ end
local function asDocFunction(source, oop)
local doc = guide.getParentType(source, 'doc.type')
or guide.getParentType(source, 'doc.overload')
- if not doc or not doc.bindSources then
+ if not doc or not doc.bindSource then
return ''
end
- for _, src in ipairs(doc.bindSources) do
- local name = buildName(src, oop)
- if name ~= '' then
- return name
- end
+ local name = buildName(doc.bindSource, oop)
+ if name ~= '' then
+ return name
end
return ''
end
diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua
index 3d8a94a5..b71b9e5d 100644
--- a/script/core/hover/return.lua
+++ b/script/core/hover/return.lua
@@ -1,34 +1,5 @@
local vm = require 'vm.vm'
-
----@param source parser.object
----@return integer
-local function countReturns(source)
- local n = 0
-
- local docs = source.bindDocs
- if docs then
- for _, doc in ipairs(docs) do
- if doc.type == 'doc.return' then
- for _, rtn in ipairs(doc.returns) do
- if rtn.returnIndex and rtn.returnIndex > n then
- n = rtn.returnIndex
- end
- end
- end
- end
- end
-
- local returns = source.returns
- if returns then
- for _, rtn in ipairs(returns) do
- if #rtn > n then
- n = #rtn
- end
- end
- end
-
- return n
-end
+local guide = require 'parser.guide'
---@param source parser.object
---@return parser.object[]
@@ -50,7 +21,7 @@ local function getReturnDocs(source)
end
local function asFunction(source)
- local num = countReturns(source)
+ local _, _, num = vm.countReturnsOfFunction(source)
if num == 0 then
return nil
end
@@ -62,11 +33,14 @@ local function asFunction(source)
for i = 1, num do
local rtn = vm.getReturnOfFunction(source, i)
local doc = docs[i]
- local name = doc and doc.name and doc.name[1] and (doc.name[1] .. ': ')
- local text = ('%s%s'):format(
+ local name = doc and doc.name and doc.name[1]
+ if name and name ~= '...' then
+ name = name .. ': '
+ end
+ local text = rtn and ('%s%s'):format(
name or '',
- vm.getInfer(rtn):view()
- )
+ vm.getInfer(rtn):view(guide.getUri(source))
+ ) or 'unknown'
if i == 1 then
returns[i] = (' -> %s'):format(text)
else
@@ -83,7 +57,14 @@ local function asDocFunction(source)
end
local returns = {}
for i, rtn in ipairs(source.returns) do
- local rtnText = vm.getInfer(rtn):view()
+ local rtnText = vm.getInfer(rtn):view(guide.getUri(source))
+ if rtn.name then
+ if rtn.name[1] == '...' then
+ rtnText = rtn.name[1] .. rtnText
+ else
+ rtnText = rtn.name[1] .. ': ' .. rtnText
+ end
+ end
if i == 1 then
returns[#returns+1] = (' -> %s'):format(rtnText)
else
diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua
index 16874101..677fd76c 100644
--- a/script/core/hover/table.lua
+++ b/script/core/hover/table.lua
@@ -30,7 +30,7 @@ local function buildAsHash(uri, keys, nodeMap, reachMax)
node:removeOptional()
end
local ifr = vm.getInfer(node)
- local typeView = ifr:view('unknown', uri)
+ local typeView = ifr:view(uri, 'unknown')
local literalView = ifr:viewLiterals()
if literalView then
lines[#lines+1] = (' %s%s: %s = %s,'):format(
@@ -75,7 +75,7 @@ local function buildAsConst(uri, keys, nodeMap, reachMax)
node = node:copy()
node:removeOptional()
end
- local typeView = vm.getInfer(node):view('unknown', uri)
+ local typeView = vm.getInfer(node):view(uri, 'unknown')
local literalView = literalMap[key]
if literalView then
lines[#lines+1] = (' %s%s: %s = %s,'):format(
@@ -154,7 +154,7 @@ local function getNodeMap(fields, keyMap)
local nodeMap = {}
for _, field in ipairs(fields) do
local key = vm.getKeyName(field)
- if not keyMap[key] then
+ if not key or not keyMap[key] then
goto CONTINUE
end
await.delay()
@@ -178,9 +178,15 @@ return function (source)
return nil
end
- for view in vm.getInfer(source):eachView() do
- if view == 'string'
- or vm.isSubType(uri, view, 'string') then
+ local node = vm.compileNode(source)
+ for n in node:eachObject() do
+ if n.type == 'global' and n.cate == 'type' then
+ if n.name == 'string'
+ or (n.name ~= 'unknown' and n.name ~= 'any' and vm.isSubType(uri, n.name, 'string')) then
+ return nil
+ end
+ elseif n.type == 'doc.type.string'
+ or n.type == 'string' then
return nil
end
end
diff --git a/script/core/jump-source.lua b/script/core/jump-source.lua
new file mode 100644
index 00000000..5ce5e048
--- /dev/null
+++ b/script/core/jump-source.lua
@@ -0,0 +1,62 @@
+local guide = require 'parser.guide'
+local furi = require 'file-uri'
+local ws = require 'workspace'
+
+---@param doc parser.object
+---@return uri
+local function parseUri(doc)
+ local uri
+ local scheme = furi.split(doc.path)
+ if scheme and #scheme >= 2 then
+ uri = doc.path
+ else
+ local suri = guide.getUri(doc):gsub('[/\\][^/\\]*$', '')
+ local path = ws.getAbsolutePath(suri, doc.path)
+ if path then
+ uri = furi.encode(path)
+ else
+ uri = doc.path
+ end
+ end
+ ---@cast uri uri
+ return uri
+end
+
+---@param results table
+return function (results)
+ for _, result in ipairs(results) do
+ if result.target.type == 'doc.field.name'
+ or result.target.type == 'doc.class.name' then
+ local doc = result.target.parent.source
+ if doc then
+ local uri = parseUri(doc)
+ result.uri = uri
+ result.target = {
+ uri = uri,
+ start = guide.positionOf(doc.line - 1, doc.char),
+ finish = guide.positionOf(doc.line - 1, doc.char),
+ }
+ end
+ else
+ local target = result.target
+ if target.type == 'method'
+ or target.type == 'field' then
+ target = target.parent
+ end
+ if target.bindDocs then
+ for _, doc in ipairs(target.bindDocs) do
+ if doc.type == 'doc.source'
+ and doc.bindSource == target then
+ local uri = parseUri(doc)
+ result.uri = uri
+ result.target = {
+ uri = uri,
+ start = guide.positionOf(doc.line - 1, doc.char),
+ finish = guide.positionOf(doc.line - 1, doc.char),
+ }
+ end
+ end
+ end
+ end
+ end
+end
diff --git a/script/core/look-backward.lua b/script/core/look-backward.lua
index eeee6017..8d3e3439 100644
--- a/script/core/look-backward.lua
+++ b/script/core/look-backward.lua
@@ -81,9 +81,19 @@ function m.findTargetSymbol(text, offset, symbol)
return nil
end
-function m.findAnyOffset(text, offset)
+---@param text string
+---@param offset integer
+---@param inline? boolean # 必须在同一行中(排除换行符)
+function m.findAnyOffset(text, offset, inline)
for i = offset, 1, -1 do
- if not m.isSpace(text:sub(i, i)) then
+ local c = text:sub(i, i)
+ if inline then
+ if c == '\r'
+ or c == '\n' then
+ return nil
+ end
+ end
+ if not m.isSpace(c) then
return i
end
end
diff --git a/script/core/reference.lua b/script/core/reference.lua
index 4c9c193d..fa838cff 100644
--- a/script/core/reference.lua
+++ b/script/core/reference.lua
@@ -2,6 +2,7 @@ local guide = require 'parser.guide'
local files = require 'files'
local vm = require 'vm'
local findSource = require 'core.find-source'
+local jumpSource = require 'core.jump-source'
local function sortResults(results)
-- 先按照顺序排序
@@ -49,6 +50,7 @@ local accept = {
['doc.class.name'] = true,
['doc.extends.name'] = true,
['doc.alias.name'] = true,
+ ['doc.enum.name'] = true,
}
---@async
@@ -101,12 +103,17 @@ return function (uri, position)
if src.type == 'doc.alias' then
src = src.alias
end
+ if src.type == 'doc.enum' then
+ src = src.enum
+ end
if src.type == 'doc.class.name'
or src.type == 'doc.alias.name'
+ or src.type == 'doc.enum.name'
or src.type == 'doc.type.name'
or src.type == 'doc.extends.name' then
if source.type ~= 'doc.type.name'
and source.type ~= 'doc.class.name'
+ and source.type ~= 'doc.enum.name'
and source.type ~= 'doc.extends.name'
and source.type ~= 'doc.see.name' then
goto CONTINUE
@@ -132,6 +139,7 @@ return function (uri, position)
end
sortResults(results)
+ jumpSource(results)
return results
end
diff --git a/script/core/rename.lua b/script/core/rename.lua
index 7599fad6..90e66224 100644
--- a/script/core/rename.lua
+++ b/script/core/rename.lua
@@ -81,6 +81,9 @@ local function renameField(source, newname, callback)
local uri = guide.getUri(source)
local text = files.getText(uri)
local state = files.getState(uri)
+ if not state or not text then
+ return false
+ end
local func = parent.value
-- function mt:name () end --> mt['newname'] = function (self) end
local startOffset = guide.positionToOffset(state, parent.start) + 1
@@ -183,13 +186,16 @@ local function ofField(source, newname, callback)
local key = guide.getKeyName(source)
local refs = vm.getRefs(source)
for _, ref in ipairs(refs) do
- ofFieldThen(key, ref, newname, callback)
+ ofFieldThen(key, ref, newname, callback)
end
end
---@async
local function ofGlobal(source, newname, callback)
local key = guide.getKeyName(source)
+ if not key then
+ return
+ end
local global = vm.getGlobal('variable', key)
if not global then
return
@@ -225,6 +231,9 @@ local function ofDocTypeName(source, newname, callback)
if doc.type == 'doc.alias' then
callback(doc, doc.alias.start, doc.alias.finish, newname)
end
+ if doc.type == 'doc.enum' then
+ callback(doc, doc.enum.start, doc.enum.finish, newname)
+ end
end
for _, doc in ipairs(global:getGets(uri)) do
if doc.type == 'doc.type.name' then
@@ -236,16 +245,15 @@ end
local function ofDocParamName(source, newname, callback)
callback(source, source.start, source.finish, newname)
local doc = source.parent
- if doc.bindSources then
- for _, src in ipairs(doc.bindSources) do
- if src.type == 'local'
- and src.parent.type == 'funcargs'
- and src[1] == source[1] then
- renameLocal(src, newname, callback)
- if src.ref then
- for _, ref in ipairs(src.ref) do
- renameLocal(ref, newname, callback)
- end
+ local src = doc.bindSource
+ if src then
+ if src.type == 'local'
+ and src.parent.type == 'funcargs'
+ and src[1] == source[1] then
+ renameLocal(src, newname, callback)
+ if src.ref then
+ for _, ref in ipairs(src.ref) do
+ renameLocal(ref, newname, callback)
end
end
end
@@ -271,7 +279,8 @@ local function rename(source, newname, callback)
return ofGlobal(source, newname, callback)
elseif source.type == 'doc.class.name'
or source.type == 'doc.type.name'
- or source.type == 'doc.alias.name' then
+ or source.type == 'doc.alias.name'
+ or source.type == 'doc.enum.name' then
return ofDocTypeName(source, newname, callback)
elseif source.type == 'doc.param.name' then
return ofDocParamName(source, newname, callback)
@@ -305,6 +314,7 @@ local function prepareRename(source)
or source.type == 'doc.class.name'
or source.type == 'doc.type.name'
or source.type == 'doc.alias.name'
+ or source.type == 'doc.enum.name'
or source.type == 'doc.param.name' then
return source, source[1]
elseif source.type == 'string'
@@ -345,6 +355,7 @@ local accept = {
['doc.type.name'] = true,
['doc.alias.name'] = true,
['doc.param.name'] = true,
+ ['doc.enum.name'] = true,
}
local m = {}
diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua
index 33449013..5833807b 100644
--- a/script/core/semantic-tokens.lua
+++ b/script/core/semantic-tokens.lua
@@ -32,7 +32,7 @@ local Care = util.switch()
end
options.libGlobals[name] = isLib
end
- local isFunc = vm.getInfer(source):hasFunction()
+ local isFunc = vm.getInfer(source):hasFunction(guide.getUri(source))
local type = isFunc and define.TokenTypes['function'] or define.TokenTypes.variable
local modifier = isLib and define.TokenModifiers.defaultLibrary or define.TokenModifiers.static
@@ -81,7 +81,7 @@ local Care = util.switch()
return
end
end
- if vm.getInfer(source):hasFunction() then
+ if vm.getInfer(source):hasFunction(guide.getUri(source)) then
results[#results+1] = {
start = source.start,
finish = source.finish,
@@ -134,19 +134,16 @@ local Care = util.switch()
return
end
local loc = source.node or source
+ local uri = guide.getUri(loc)
-- 1. 值为函数的局部变量 | Local variable whose value is a function
- if loc.ref then
- for _, ref in ipairs(loc.ref) do
- if ref.value and ref.value.type == 'function' then
- results[#results+1] = {
- start = source.start,
- finish = source.finish,
- type = define.TokenTypes['function'],
- modifieres = define.TokenModifiers.declaration,
- }
- return
- end
- end
+ if vm.getInfer(source):hasFunction(uri) then
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes['function'],
+ modifieres = define.TokenModifiers.declaration,
+ }
+ return
end
-- 3. 特殊变量 | Special variableif source[1] == '_ENV' then
if loc[1] == '_ENV' then
@@ -196,7 +193,7 @@ local Care = util.switch()
end
end
-- 6. References to other functions
- if vm.getInfer(loc):hasFunction() then
+ if vm.getInfer(loc):hasFunction(guide.getUri(source)) then
results[#results+1] = {
start = source.start,
finish = source.finish,
@@ -449,6 +446,7 @@ local Care = util.switch()
end
end)
: case 'doc.alias.name'
+ : case 'doc.enum.name'
: call(function (source, options, results)
if not options.annotation then
return
@@ -667,6 +665,14 @@ local Care = util.switch()
type = define.TokenTypes.keyword,
}
end)
+ : case 'doc.cast.block'
+ : call(function (source, options, results)
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.keyword,
+ }
+ end)
: case 'doc.cast.name'
: call(function (source, options, results)
results[#results+1] = {
@@ -675,6 +681,23 @@ local Care = util.switch()
type = define.TokenTypes.variable,
}
end)
+ : case 'doc.type.code'
+ : call(function (source, options, results)
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.string,
+ modifieres = define.TokenModifiers.abstract,
+ }
+ end)
+ : case 'doc.operator.name'
+ : call(function (source, options, results)
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.operator,
+ }
+ end)
local function buildTokens(uri, results)
local tokens = {}
@@ -811,9 +834,13 @@ return function (uri, start, finish)
keyword = config.get(uri, 'Lua.semantic.keyword'),
}
+ local n = 0
guide.eachSourceBetween(state.ast, start, finish, function (source) ---@async
Care(source.type, source, options, results)
- await.delay()
+ n = n + 1
+ if n % 100 == 0 then
+ await.delay()
+ end
end)
for _, comm in ipairs(state.comms) do
diff --git a/script/core/signature.lua b/script/core/signature.lua
index 025e70b7..21e954bf 100644
--- a/script/core/signature.lua
+++ b/script/core/signature.lua
@@ -8,6 +8,9 @@ local lookback = require 'core.look-backward'
local function findNearCall(uri, ast, pos)
local text = files.getText(uri)
local state = files.getState(uri)
+ if not state or not text then
+ return nil
+ end
local nearCall
guide.eachSourceContain(ast.ast, pos, function (src)
if src.type == 'call'
@@ -65,27 +68,30 @@ local function makeOneSignature(source, oop, index)
}
end
-- 不定参数
- if index > i and i > 0 then
+ if index and index > i and i > 0 then
local lastLabel = params[i].label
local text = label:sub(lastLabel[1] + 1, lastLabel[2])
if text:sub(1, 3) == '...' then
index = i
end
end
+ if #params < (index or 0) then
+ return nil
+ end
return {
label = label,
params = params,
- index = index,
+ index = index or 1,
description = hoverDesc(source),
}
end
---@async
local function makeSignatures(text, call, pos)
- local node = call.node
- local oop = node.type == 'method'
- or node.type == 'getmethod'
- or node.type == 'setmethod'
+ local func = call.node
+ local oop = func.type == 'method'
+ or func.type == 'getmethod'
+ or func.type == 'setmethod'
local index
if call.args then
local args = {}
@@ -121,13 +127,13 @@ local function makeSignatures(text, call, pos)
index = #args
end
end
- else
- index = 1
end
local signs = {}
- local defs = vm.getDefs(node)
+ local node = vm.compileNode(func)
+ ---@type vm.node
+ node = node:getData 'originNode' or node
local mark = {}
- for _, src in ipairs(defs) do
+ for src in node:eachObject() do
if src.type == 'function'
or src.type == 'doc.type.function' then
if not mark[src] then
@@ -142,10 +148,10 @@ end
---@async
return function (uri, pos)
local state = files.getState(uri)
- if not state then
+ local text = files.getText(uri)
+ if not state or not text then
return nil
end
- local text = files.getText(uri)
local offset = guide.positionToOffset(state, pos)
pos = guide.offsetToPosition(state, lookback.skipSpace(text, offset))
local call = findNearCall(uri, state, pos)
@@ -156,5 +162,8 @@ return function (uri, pos)
if not signs or #signs == 0 then
return nil
end
+ table.sort(signs, function (a, b)
+ return #a.params < #b.params
+ end)
return signs
end
diff --git a/script/core/type-definition.lua b/script/core/type-definition.lua
index d8434c8c..a1c2b29f 100644
--- a/script/core/type-definition.lua
+++ b/script/core/type-definition.lua
@@ -4,6 +4,7 @@ local vm = require 'vm'
local findSource = require 'core.find-source'
local guide = require 'parser.guide'
local rpath = require 'workspace.require-path'
+local jumpSource = require 'core.jump-source'
local function sortResults(results)
-- 先按照顺序排序
@@ -51,6 +52,7 @@ local accept = {
['doc.class.name'] = true,
['doc.extends.name'] = true,
['doc.alias.name'] = true,
+ ['doc.enum.name'] = true,
['doc.see.name'] = true,
['doc.see.field'] = true,
}
@@ -74,7 +76,7 @@ local function checkRequire(source, offset)
return nil
end
if libName == 'require' then
- return rpath.findUrisByRequirePath(guide.getUri(source), literal)
+ return rpath.findUrisByRequireName(guide.getUri(source), literal)
elseif libName == 'dofile'
or libName == 'loadfile' then
return workspace.findUrisByFilePath(literal)
@@ -144,6 +146,9 @@ return function (uri, offset)
if src.type == 'doc.alias' then
src = src.alias
end
+ if src.type == 'doc.enum' then
+ src = src.enum
+ end
if src.type == 'doc.class.name'
or src.type == 'doc.alias.name'
or src.type == 'doc.type.function'
@@ -164,6 +169,7 @@ return function (uri, offset)
end
sortResults(results)
+ jumpSource(results)
return results
end
diff --git a/script/doctor.lua b/script/doctor.lua
index 87cdcfcb..e1044689 100644
--- a/script/doctor.lua
+++ b/script/doctor.lua
@@ -538,7 +538,7 @@ m.exclude = private(function (...)
end)
--- 比较2个报告
----@return string
+---@return table
m.compare = private(function (old, new)
local newHash = {}
local ret = {}
diff --git a/script/file-uri.lua b/script/file-uri.lua
index ccd47156..3e916acf 100644
--- a/script/file-uri.lua
+++ b/script/file-uri.lua
@@ -24,9 +24,6 @@ local m = {}
---@param path string
---@return uri uri
function m.encode(path)
- if not path then
- return nil
- end
local authority = ''
if platform.OS == 'Windows' then
path = path:gsub('\\', '/')
@@ -70,9 +67,6 @@ end
---@param uri uri
---@return string path
function m.decode(uri)
- if not uri then
- return nil
- end
local scheme, authority, path = uri:match('([^:]*):?/?/?([^/]*)(.*)')
if not scheme then
return ''
@@ -95,7 +89,27 @@ function m.decode(uri)
end
function m.split(uri)
- return uri:match('([^:]*):?/?/?([^/]*)(.*)')
+ return uri:match('([^:]*):/?/?([^/]*)(.*)')
+end
+
+---@param uri string
+---@return boolean
+function m.isValid(uri)
+ local scheme, authority, path = m.split(uri)
+ if not scheme or scheme == '' then
+ return false
+ end
+ if path == '' then
+ return false
+ end
+ return true
+end
+
+function m.normalize(uri)
+ if uri == '' then
+ return uri
+ end
+ return m.encode(m.decode(uri))
end
return m
diff --git a/script/files.lua b/script/files.lua
index 22c9ae31..91bbe570 100644
--- a/script/files.lua
+++ b/script/files.lua
@@ -14,16 +14,30 @@ local progress = require "progress"
local encoder = require 'encoder'
local scope = require 'workspace.scope'
+---@class file
+---@field uri uri
+---@field content string
+---@field _ref? integer
+---@field trusted? boolean
+---@field rows? integer[]
+---@field originText? string
+---@field text string
+---@field version? integer
+---@field originLines? integer[]
+---@field state? parser.state
+---@field _diffInfo? table[]
+---@field cache table
+
---@class files
local m = {}
m.watchList = {}
m.notifyCache = {}
m.assocVersion = -1
-m.assocMatcher = nil
function m.reset()
m.openMap = {}
+ ---@type table<string, file>
m.fileMap = {}
m.dllMap = {}
m.visible = {}
@@ -41,12 +55,14 @@ local uriMap = {}
---@return uri
function m.getRealUri(uri)
local filename = furi.decode(uri)
+ -- normalize uri
+ uri = furi.encode(filename)
local path = fs.path(filename)
- local suc, res = pcall(fs.exists, path)
- if not suc or not res then
+ local suc, exists = pcall(fs.exists, path)
+ if not suc or not exists then
return uri
end
- suc, res = pcall(fs.canonical, path)
+ local suc, res = pcall(fs.canonical, path)
if not suc then
return uri
end
@@ -123,6 +139,7 @@ function m.isLibrary(uri, excludeFolder)
end
--- 获取库文件的根目录
+---@return uri?
function m.getLibraryUri(suri, uri)
local scp = scope.getScope(suri)
return scp:getLinkedUri(uri)
@@ -134,6 +151,9 @@ function m.exists(uri)
return m.fileMap[uri] ~= nil
end
+---@param file file
+---@param text string
+---@return string
local function pluginOnSetText(file, text)
local plugin = require 'plugin'
file._diffInfo = nil
@@ -164,7 +184,7 @@ end
--- 设置文件文本
---@param uri uri
----@param text string
+---@param text? string
---@param isTrust? boolean
---@param callback? function
function m.setText(uri, text, isTrust, callback)
@@ -320,7 +340,7 @@ end
--- 获取文件文本
---@param uri uri
----@return string text
+---@return string? text
function m.getText(uri)
local file = m.fileMap[uri]
if not file then
@@ -331,7 +351,7 @@ end
--- 获取文件原始文本
---@param uri uri
----@return string text
+---@return string? text
function m.getOriginText(uri)
local file = m.fileMap[uri]
if not file then
@@ -345,9 +365,7 @@ end
---@return integer[]
function m.getOriginLines(uri)
local file = m.fileMap[uri]
- if not file then
- return nil
- end
+ assert(file, 'file not exists:' .. uri)
return file.originLines
end
@@ -444,7 +462,6 @@ function m.eachFile(suri)
end
--- Pairs dll files
----@return function
function m.eachDll()
local map = {}
for uri, file in pairs(m.dllMap) do
@@ -488,7 +505,7 @@ function m.compileState(uri, text)
, {
special = config.get(uri, 'Lua.runtime.special'),
unicodeName = config.get(uri, 'Lua.runtime.unicodeName'),
- nonstandardSymbol = config.get(uri, 'Lua.runtime.nonstandardSymbol'),
+ nonstandardSymbol = util.arrayToHash(config.get(uri, 'Lua.runtime.nonstandardSymbol')),
}
)
local passed = os.clock() - clock
@@ -524,7 +541,7 @@ end
--- 获取文件语法树
---@param uri uri
----@return table state
+---@return table? state
function m.getState(uri)
local file = m.fileMap[uri]
if not file then
@@ -534,7 +551,7 @@ function m.getState(uri)
if not state then
state = m.compileState(uri, file.text)
m.astMap[uri] = state
- file.ast = state
+ file.state = state
--await.delay()
end
file.cacheActiveTime = timer.clock()
@@ -546,7 +563,7 @@ function m.getLastState(uri)
if not file then
return nil
end
- return file.ast
+ return file.state
end
function m.getFile(uri)
diff --git a/script/filewatch.lua b/script/filewatch.lua
index 5e3a0322..ecc72255 100644
--- a/script/filewatch.lua
+++ b/script/filewatch.lua
@@ -5,13 +5,13 @@ local await = require 'await'
local MODIFY = 1 << 0
local RENAME = 1 << 1
-local function exists(filename)
+local function isExists(filename)
local path = fs.path(filename)
- local suc, res = pcall(fs.exists, path)
- if not suc or not res then
+ local suc, exists = pcall(fs.exists, path)
+ if not suc or not exists then
return false
end
- suc, res = pcall(fs.canonical, path)
+ local suc, res = pcall(fs.canonical, path)
if not suc or res:string() ~= path:string() then
return false
end
@@ -69,6 +69,7 @@ function m.update()
if not ev then
break
end
+ log.debug('filewatch:', ev, path)
if not collect then
collect = {}
end
@@ -85,7 +86,7 @@ function m.update()
for path, flag in pairs(collect) do
if flag & RENAME ~= 0 then
- if exists(path) then
+ if isExists(path) then
m._callEvent('create', path)
else
m._callEvent('delete', path)
diff --git a/script/fs-utility.lua b/script/fs-utility.lua
index 08aae98a..b789177c 100644
--- a/script/fs-utility.lua
+++ b/script/fs-utility.lua
@@ -13,9 +13,10 @@ local tableSort = table.sort
_ENV = nil
+---@class fs-utility
local m = {}
--- 读取文件
----@param path string
+---@param path string|fs.path
function m.loadFile(path, keepBom)
if type(path) ~= 'string' then
---@diagnostic disable-next-line: undefined-field
@@ -40,7 +41,7 @@ function m.loadFile(path, keepBom)
end
--- 写入文件
----@param path string
+---@param path any
---@param content string
function m.saveFile(path, content)
if type(path) ~= 'string' then
@@ -255,6 +256,9 @@ function dfs:saveFile(path, text)
dir[filename] = text
end
+---@param path string|fs.path
+---@param option table
+---@return fs.path?
local function fsAbsolute(path, option)
if type(path) == 'string' then
local suc, res = pcall(fs.path, path)
@@ -444,6 +448,9 @@ local function fileRemove(path, option)
end
end
+---@param source fs.path?
+---@param target fs.path?
+---@param option table
local function fileCopy(source, target, option)
if not source or not target then
return
@@ -477,6 +484,9 @@ local function fileCopy(source, target, option)
end
end
+---@param source fs.path?
+---@param target fs.path?
+---@param option table
local function fileSync(source, target, option)
if not source or not target then
return
@@ -583,29 +593,29 @@ function m.fileRemove(path, option)
end
--- 复制文件(夹)
----@param source string
----@param target string
+---@param source string|fs.path
+---@param target string|fs.path
---@return table
function m.fileCopy(source, target, option)
option = buildOption(option)
- source = fsAbsolute(source, option)
- target = fsAbsolute(target, option)
+ local fsSource = fsAbsolute(source, option)
+ local fsTarget = fsAbsolute(target, option)
- fileCopy(source, target, option)
+ fileCopy(fsSource, fsTarget, option)
return option
end
--- 同步文件(夹)
----@param source string
----@param target string
+---@param source string|fs.path
+---@param target string|fs.path
---@return table
function m.fileSync(source, target, option)
option = buildOption(option)
- source = fsAbsolute(source, option)
- target = fsAbsolute(target, option)
+ local fsSource = fsAbsolute(source, option)
+ local fsTarget = fsAbsolute(target, option)
- fileSync(source, target, option)
+ fileSync(fsSource, fsTarget, option)
return option
end
diff --git a/script/glob/gitignore.lua b/script/glob/gitignore.lua
index 4dad2747..de8fd005 100644
--- a/script/glob/gitignore.lua
+++ b/script/glob/gitignore.lua
@@ -164,12 +164,10 @@ function mt:getRelativePath(path)
end
---@param callback async fun(path: string)
+---@param hook? async fun(ev: string, ...)
---@async
-function mt:scan(path, callback)
+function mt:scan(path, callback, hook)
local files = {}
- if type(callback) ~= 'function' then
- callback = nil
- end
local list = {}
---@async
@@ -203,6 +201,9 @@ function mt:scan(path, callback)
break
end
list[#list] = nil
+ if hook then
+ hook('scan', current)
+ end
if not self:simpleMatch(current) then
check(current)
end
diff --git a/script/jsonrpc.lua b/script/jsonrpc.lua
index 91d6c9dd..7411fee8 100644
--- a/script/jsonrpc.lua
+++ b/script/jsonrpc.lua
@@ -50,6 +50,7 @@ function m.decode(reader)
if not content then
return nil, 'Proto read error'
end
+ ---@type any
local null = json.null
json.null = nil
local suc, res = pcall(json.decode, content)
diff --git a/script/lclient.lua b/script/lclient.lua
index ad1fff3d..e1504e61 100644
--- a/script/lclient.lua
+++ b/script/lclient.lua
@@ -80,7 +80,6 @@ function mt:reportHangs()
end
---@param callback async fun(client: languageClient)
----@return languageClient
function mt:start(callback)
CLI = true
@@ -208,8 +207,10 @@ function mt:registerFakers()
'textDocument/publishDiagnostics',
'workspace/configuration',
'workspace/semanticTokens/refresh',
+ 'workspace/diagnostic/refresh',
'window/workDoneProgress/create',
'window/showMessage',
+ 'window/showMessageRequest',
'window/logMessage',
} do
self:register(method, function ()
diff --git a/script/library.lua b/script/library.lua
index 66c4d364..57aac066 100644
--- a/script/library.lua
+++ b/script/library.lua
@@ -209,6 +209,7 @@ local function initBuiltIn(uri)
local langID = lang.id
local version = config.get(uri, 'Lua.runtime.version')
local encoding = config.get(uri, 'Lua.runtime.fileEncoding')
+ ---@type fs.path
local metaPath = fs.path(METAPATH) / config.get(uri, 'Lua.runtime.meta'):gsub('%$%{(.-)%}', {
version = version,
language = langID,
@@ -243,6 +244,7 @@ local function initBuiltIn(uri)
goto CONTINUE
end
libName = libName .. '.lua'
+ ---@type fs.path
local libPath = templateDir / libName
local metaDoc = compileSingleMetaDoc(uri, fsu.loadFile(libPath), metaLang, status)
if metaDoc then
@@ -260,6 +262,7 @@ local function initBuiltIn(uri)
end
end
+---@param libraryDir fs.path
local function loadSingle3rdConfig(libraryDir)
local configText = fsu.loadFile(libraryDir / 'config.lua')
if not configText then
@@ -321,7 +324,13 @@ local function load3rdConfigInDir(dir, configs, inner)
end
local function load3rdConfig(uri)
- local configs = {}
+ local scp = scope.getScope(uri)
+ local configs = scp:get 'thirdConfigsCache'
+ if configs then
+ return configs
+ end
+ configs = {}
+ scp:set('thirdConfigsCache', configs)
load3rdConfigInDir(innerThirdDir, configs, true)
local thirdDirs = config.get(uri, 'Lua.workspace.userThirdParty')
for _, thirdDir in ipairs(thirdDirs) do
@@ -400,6 +409,15 @@ local function askFor3rd(uri, cfg)
uri = uri,
},
}, true)
+ else
+ client.setConfig({
+ {
+ key = 'Lua.workspace.checkThirdParty',
+ action = 'set',
+ value = false,
+ uri = uri,
+ },
+ }, false)
end
end
@@ -420,11 +438,21 @@ local function wholeMatch(a, b)
return true
end
-local function check3rdByWords(uri, text, configs)
+local function check3rdByWords(uri, configs)
if hasAsked then
return
end
+ if not files.isLua(uri) then
+ return
+ end
+ local id = 'check3rdByWords:' .. uri
+ await.close(id)
await.call(function () ---@async
+ await.sleep(0.1)
+ local text = files.getText(uri)
+ if not text then
+ return
+ end
for _, cfg in ipairs(configs) do
if cfg.words then
for _, word in ipairs(cfg.words) do
@@ -436,7 +464,7 @@ local function check3rdByWords(uri, text, configs)
end
end
end
- end)
+ end, id)
end
local function check3rdByFileName(uri, configs)
@@ -447,7 +475,10 @@ local function check3rdByFileName(uri, configs)
if not path then
return
end
+ local id = 'check3rdByFileName:' .. uri
+ await.close(id)
await.call(function () ---@async
+ await.sleep(0.1)
for _, cfg in ipairs(configs) do
if cfg.files then
for _, filename in ipairs(cfg.files) do
@@ -459,50 +490,62 @@ local function check3rdByFileName(uri, configs)
end
end
end
- end)
-end
-
-local lastCheckedUri = {}
-local function checkedUri(uri)
- if lastCheckedUri[uri]
- and timer.clock() - lastCheckedUri[uri] < 5 then
- return false
- end
- lastCheckedUri[uri] = timer.clock()
- return true
+ end, id)
end
-local thirdConfigs
+---@async
local function check3rd(uri)
if hasAsked then
return
end
+ if ws.isIgnored(uri) then
+ return
+ end
if not config.get(uri, 'Lua.workspace.checkThirdParty') then
return
end
- if thirdConfigs == nil then
- thirdConfigs = load3rdConfig(uri) or false
+ local scp = scope.getScope(uri)
+ if not scp:get 'canCheckThirdParty' then
+ return
end
+ local thirdConfigs = load3rdConfig(uri) or false
if not thirdConfigs then
return
end
- if checkedUri(uri) then
- if files.isLua(uri) then
- local text = files.getText(uri)
- if text then
- check3rdByWords(uri, text, thirdConfigs)
- end
+ check3rdByWords(uri, thirdConfigs)
+ check3rdByFileName(uri, thirdConfigs)
+end
+
+local function check3rdOfWorkspace(suri)
+ local scp = scope.getScope(suri)
+ scp:set('thirdConfigsCache', nil)
+ scp:set('canCheckThirdParty', true)
+ local id = 'check3rdOfWorkspace:' .. scp:getName()
+ await.close(id)
+ ---@async
+ await.call(function ()
+ ws.awaitReady(suri)
+ for uri in files.eachFile(suri) do
+ check3rd(uri)
end
- check3rdByFileName(uri, thirdConfigs)
- end
+ for uri in files.eachDll() do
+ check3rd(uri)
+ end
+ end, id)
end
config.watch(function (uri, key, value, oldValue)
if key:find '^Lua.runtime' then
initBuiltIn(uri)
end
+ if key == 'Lua.workspace.checkThirdParty'
+ or key == 'Lua.workspace.userThirdParty'
+ or key == '' then
+ check3rdOfWorkspace(uri)
+ end
end)
+---@async
files.watch(function (ev, uri)
if ev == 'update'
or ev == 'dll' then
@@ -510,11 +553,10 @@ files.watch(function (ev, uri)
end
end)
-function m.init()
- initBuiltIn(nil)
- for _, scp in ipairs(ws.folders) do
- initBuiltIn(scp.uri)
+ws.watch(function (ev, uri)
+ if ev == 'startReload' then
+ initBuiltIn(uri)
end
-end
+end)
return m
diff --git a/script/linked-table.lua b/script/linked-table.lua
index 4d87e943..a63a528c 100644
--- a/script/linked-table.lua
+++ b/script/linked-table.lua
@@ -8,10 +8,14 @@ mt._size = 0
local HEAD = {'<HEAD>'}
local TAIL = {'<TAIL>'}
+---@param node any
+---@return boolean
function mt:has(node)
return self._left[node] ~= nil
end
+---@param node any
+---@return boolean
function mt:isValidNode(node)
if node == nil
or node == HEAD
@@ -21,6 +25,9 @@ function mt:isValidNode(node)
return true
end
+---@param node any
+---@param afterWho any
+---@return boolean
function mt:pushAfter(node, afterWho)
if not self:isValidNode(node) then
return false
@@ -41,6 +48,9 @@ function mt:pushAfter(node, afterWho)
return true
end
+---@param node any
+---@param beforeWho any
+---@return boolean
function mt:pushBefore(node, beforeWho)
if node == nil then
return false
@@ -52,6 +62,8 @@ function mt:pushBefore(node, beforeWho)
return self:pushAfter(node, left)
end
+---@param node any
+---@return boolean
function mt:pop(node)
if not self:isValidNode(node) then
return false
@@ -71,14 +83,20 @@ function mt:pop(node)
return true
end
+---@param node any
+---@return boolean
function mt:pushHead(node)
return self:pushAfter(node, HEAD)
end
+---@param node any
+---@return boolean
function mt:pushTail(node)
return self:pushBefore(node, TAIL)
end
+---@param node any
+---@return any
function mt:getAfter(node)
if node == nil then
node = HEAD
@@ -90,10 +108,12 @@ function mt:getAfter(node)
return right
end
+---@return any
function mt:getHead()
return self:getAfter(HEAD)
end
+---@return any
function mt:getBefore(node)
if node == nil then
node = TAIL
@@ -105,18 +125,24 @@ function mt:getBefore(node)
return left
end
+---@return any
function mt:getTail()
return self:getBefore(TAIL)
end
+---@return boolean
function mt:popHead()
return self:pop(self:getHead())
end
+---@return boolean
function mt:popTail()
return self:pop(self:getTail())
end
+---@param old any
+---@param new any
+---@return boolean
function mt:replace(old, new)
if not self:isValidNode(old)
or not self:isValidNode(new) then
@@ -137,10 +163,14 @@ function mt:replace(old, new)
return true
end
+---@return integer
function mt:getSize()
return self._size
end
+---@param start any
+---@param revert? boolean
+---@return fun():any
function mt:pairs(start, revert)
if revert then
if start == nil then
@@ -171,6 +201,9 @@ function mt:pairs(start, revert)
end
end
+---@param start any
+---@param revert? boolean
+---@return string
function mt:dump(start, revert)
local t = {}
for node in self:pairs(start, revert) do
@@ -186,6 +219,7 @@ function mt:reset()
self._size = 0
end
+---@return linked-table
return function ()
local self = setmetatable({}, mt)
self:reset()
diff --git a/script/meta/bee/filesystem.lua b/script/meta/bee/filesystem.lua
new file mode 100644
index 00000000..f6cdff79
--- /dev/null
+++ b/script/meta/bee/filesystem.lua
@@ -0,0 +1,91 @@
+---@class fs.path
+---@operator div: fs.path
+local fsPath = {}
+
+---@return string
+function fsPath:string()
+end
+
+---@return fs.path
+function fsPath:parent_path()
+end
+
+---@return boolean
+function fsPath:is_relative()
+end
+
+---@return fs.path
+function fsPath:filename()
+end
+
+---@class fs.status
+local fsStatus = {}
+
+---@return string
+function fsStatus:type()
+end
+
+---@class fs
+local fs = {}
+
+---@class fs.copy_options
+---@field overwrite_existing integer
+local copy_options
+
+fs.copy_options = copy_options
+
+---@param path string
+---@return fs.path
+function fs.path(path)
+end
+
+---@return fs.path
+function fs.exe_path()
+end
+
+---@param path fs.path
+---@return boolean
+function fs.exists(path)
+end
+
+---@param path fs.path
+---@return boolean
+function fs.is_directory(path)
+end
+
+---@param path fs.path
+---@return fun():fs.path
+function fs.pairs(path)
+end
+
+---@param path fs.path
+---@return fs.path
+function fs.canonical(path)
+end
+
+---@param path fs.path
+---@return fs.path
+function fs.absolute(path)
+end
+
+---@param path fs.path
+function fs.create_directories(path)
+end
+
+---@param path fs.path
+---@return fs.status
+function fs.symlink_status(path)
+end
+
+---@param path fs.path
+---@return boolean
+function fs.remove(path)
+end
+
+---@param source fs.path
+---@param target fs.path
+---@param options? `fs.copy_options.overwrite_existing`
+function fs.copy_file(source, target, options)
+end
+
+return fs
diff --git a/script/parser/ast.lua b/script/parser/ast.lua
deleted file mode 100644
index 648a6890..00000000
--- a/script/parser/ast.lua
+++ /dev/null
@@ -1,1997 +0,0 @@
-local tonumber = tonumber
-local stringChar = string.char
-local utf8Char = utf8.char
-local tableUnpack = table.unpack
-local mathType = math.type
-local tableRemove = table.remove
-local tableSort = table.sort
-local print = print
-local tostring = tostring
-
-_ENV = nil
-
-local DefaultState = {
- lua = '',
- options = {},
-}
-
-local State = DefaultState
-local PushError
-local PushDiag
-local PushComment
-
--- goto 单独处理
-local RESERVED = {
- ['and'] = true,
- ['break'] = true,
- ['do'] = true,
- ['else'] = true,
- ['elseif'] = true,
- ['end'] = true,
- ['false'] = true,
- ['for'] = true,
- ['function'] = true,
- ['if'] = true,
- ['in'] = true,
- ['local'] = true,
- ['nil'] = true,
- ['not'] = true,
- ['or'] = true,
- ['repeat'] = true,
- ['return'] = true,
- ['then'] = true,
- ['true'] = true,
- ['until'] = true,
- ['while'] = true,
-}
-
-local VersionOp = {
- ['&'] = {'Lua 5.3', 'Lua 5.4'},
- ['~'] = {'Lua 5.3', 'Lua 5.4'},
- ['|'] = {'Lua 5.3', 'Lua 5.4'},
- ['<<'] = {'Lua 5.3', 'Lua 5.4'},
- ['>>'] = {'Lua 5.3', 'Lua 5.4'},
- ['//'] = {'Lua 5.3', 'Lua 5.4'},
-}
-
-local SymbolAlias = {
- ['||'] = 'or',
- ['&&'] = 'and',
- ['!='] = '~=',
- ['!'] = 'not',
-}
-
-local function checkOpVersion(op)
- local versions = VersionOp[op.type]
- if not versions then
- return
- end
- for i = 1, #versions do
- if versions[i] == State.version then
- return
- end
- end
- PushError {
- type = 'UNSUPPORT_SYMBOL',
- start = op.start,
- finish = op.finish,
- version = versions,
- info = {
- version = State.version,
- }
- }
-end
-
-local function checkMissEnd(start)
- if not State.MissEndErr then
- return
- end
- local err = State.MissEndErr
- State.MissEndErr = nil
- local _, finish = State.lua:find('[%w_]+', start)
- if not finish then
- return
- end
- err.info.related = {
- {
- start = start,
- finish = finish,
- }
- }
- PushError {
- type = 'MISS_END',
- start = start,
- finish = finish,
- }
-end
-
-local function getSelect(vararg, index)
- return {
- type = 'select',
- start = vararg.start,
- finish = vararg.finish,
- vararg = vararg,
- sindex = index,
- }
-end
-
-local function getValue(values, i)
- if not values then
- return nil, nil
- end
- local value = values[i]
- if not value then
- local last = values[#values]
- if not last then
- return nil, nil
- end
- if last.type == 'call' or last.type == 'varargs' then
- return getSelect(last, i - #values + 1)
- end
- return nil, nil
- end
- if value.type == 'call' or value.type == 'varargs' then
- value = getSelect(value, 1)
- end
- return value
-end
-
-local function createLocal(key, effect, value, attrs)
- if not key then
- return nil
- end
- key.type = 'local'
- key.effect = effect
- key.value = value
- key.attrs = attrs
- if value then
- key.range = value.finish
- end
- return key
-end
-
-local function createCall(args, start, finish)
- if args then
- args.type = 'callargs'
- args.start = start
- args.finish = finish
- end
- return {
- type = 'call',
- start = start,
- finish = finish,
- args = args,
- }
-end
-
-local function packList(start, list, finish)
- local lastFinish = start
- local wantName = true
- local count = 0
- for i = 1, #list do
- local ast = list[i]
- if ast.type == ',' then
- if wantName or i == #list then
- PushError {
- type = 'UNEXPECT_SYMBOL',
- start = ast.start,
- finish = ast.finish,
- info = {
- symbol = ',',
- }
- }
- end
- wantName = true
- else
- if not wantName then
- PushError {
- type = 'MISS_SYMBOL',
- start = lastFinish,
- finish = ast.start - 1,
- info = {
- symbol = ',',
- }
- }
- end
- wantName = false
- count = count + 1
- list[count] = list[i]
- end
- lastFinish = ast.finish + 1
- end
- for i = count + 1, #list do
- list[i] = nil
- end
- list.type = 'list'
- list.start = start
- list.finish = finish - 1
- return list
-end
-
-local BinaryLevel = {
- ['or'] = 1,
- ['and'] = 2,
- ['<='] = 3,
- ['>='] = 3,
- ['<'] = 3,
- ['>'] = 3,
- ['~='] = 3,
- ['=='] = 3,
- ['|'] = 4,
- ['~'] = 5,
- ['&'] = 6,
- ['<<'] = 7,
- ['>>'] = 7,
- ['..'] = 8,
- ['+'] = 9,
- ['-'] = 9,
- ['*'] = 10,
- ['//'] = 10,
- ['/'] = 10,
- ['%'] = 10,
- ['^'] = 11,
-}
-
-local BinaryForward = {
- [01] = true,
- [02] = true,
- [03] = true,
- [04] = true,
- [05] = true,
- [06] = true,
- [07] = true,
- [08] = false,
- [09] = true,
- [10] = true,
- [11] = false,
-}
-
-local Defs = {
- Nil = function (pos)
- return {
- type = 'nil',
- start = pos,
- finish = pos + 2,
- }
- end,
- True = function (pos)
- return {
- type = 'boolean',
- start = pos,
- finish = pos + 3,
- [1] = true,
- }
- end,
- False = function (pos)
- return {
- type = 'boolean',
- start = pos,
- finish = pos + 4,
- [1] = false,
- }
- end,
- ShortComment = function (start, text, finish)
- PushComment {
- type = 'comment.short',
- start = start,
- finish = finish - 1,
- text = text,
- }
- end,
- LongComment = function (start, beforeEq, afterEq, str, close, finish)
- PushComment {
- type = 'comment.long',
- start = start,
- finish = finish - 1,
- text = str,
- }
- if not close then
- local endSymbol = ']' .. ('='):rep(afterEq-beforeEq) .. ']'
- local s, _, w = str:find('(%][%=]*%])[%c%s]*$')
- if s then
- PushError {
- type = 'ERR_LCOMMENT_END',
- start = finish - #str + s - 1,
- finish = finish - #str + s + #w - 2,
- info = {
- symbol = endSymbol,
- },
- fix = {
- title = 'FIX_LCOMMENT_END',
- {
- start = finish - #str + s - 1,
- finish = finish - #str + s + #w - 2,
- text = endSymbol,
- }
- },
- }
- end
- PushError {
- type = 'MISS_SYMBOL',
- start = finish,
- finish = finish,
- info = {
- symbol = endSymbol,
- },
- fix = {
- title = 'ADD_LCOMMENT_END',
- {
- start = finish,
- finish = finish,
- text = endSymbol,
- }
- },
- }
- end
- end,
- CLongComment = function (start1, finish1, str, start2, finish2)
- if State.options.nonstandardSymbol and State.options.nonstandardSymbol['/**/'] then
- else
- PushError {
- type = 'ERR_C_LONG_COMMENT',
- start = start1,
- finish = finish2 - 1,
- fix = {
- title = 'FIX_C_LONG_COMMENT',
- {
- start = start1,
- finish = finish1 - 1,
- text = '--[[',
- },
- {
- start = start2,
- finish = finish2 - 1,
- text = '--]]'
- },
- }
- }
- end
- PushComment {
- type = 'comment.clong',
- start = start1,
- finish = finish2 - 1,
- text = str,
- }
- end,
- CCommentPrefix = function (start, finish, commentFinish)
- if State.options.nonstandardSymbol and State.options.nonstandardSymbol['//'] then
- else
- PushError {
- type = 'ERR_COMMENT_PREFIX',
- start = start,
- finish = finish - 1,
- fix = {
- title = 'FIX_COMMENT_PREFIX',
- {
- start = start,
- finish = finish - 1,
- text = '--',
- },
- }
- }
- end
- PushComment {
- type = 'comment.cshort',
- start = start,
- finish = commentFinish - 1,
- text = '',
- }
- end,
- String = function (start, quote, str, finish)
- if quote == '`' then
- if State.options.nonstandardSymbol and State.options.nonstandardSymbol['`'] then
- else
- PushError {
- type = 'ERR_NONSTANDARD_SYMBOL',
- start = start,
- finish = finish - 1,
- info = {
- symbol = '"',
- },
- fix = {
- title = 'FIX_NONSTANDARD_SYMBOL',
- symbol = '"',
- {
- start = start,
- finish = start,
- text = '"',
- },
- {
- start = finish - 1,
- finish = finish - 1,
- text = '"',
- },
- }
- }
- end
- end
- return {
- type = 'string',
- start = start,
- finish = finish - 1,
- [1] = str,
- [2] = quote,
- }
- end,
- LongString = function (beforeEq, afterEq, str, missPos)
- if missPos then
- local endSymbol = ']' .. ('='):rep(afterEq-beforeEq) .. ']'
- local s, _, w = str:find('(%][%=]*%])[%c%s]*$')
- if s then
- PushError {
- type = 'ERR_LSTRING_END',
- start = missPos - #str + s - 1,
- finish = missPos - #str + s + #w - 2,
- info = {
- symbol = endSymbol,
- },
- fix = {
- title = 'FIX_LSTRING_END',
- {
- start = missPos - #str + s - 1,
- finish = missPos - #str + s + #w - 2,
- text = endSymbol,
- }
- },
- }
- end
- PushError {
- type = 'MISS_SYMBOL',
- start = missPos,
- finish = missPos,
- info = {
- symbol = endSymbol,
- },
- fix = {
- title = 'ADD_LSTRING_END',
- {
- start = missPos,
- finish = missPos,
- text = endSymbol,
- }
- },
- }
- end
- return '[' .. ('='):rep(afterEq-beforeEq) .. '[', str
- end,
- Char10 = function (char)
- char = tonumber(char)
- if not char or char < 0 or char > 255 then
- return ''
- end
- return stringChar(char)
- end,
- Char16 = function (pos, char)
- if State.version == 'Lua 5.1' then
- PushError {
- type = 'ERR_ESC',
- start = pos-1,
- finish = pos,
- version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
- info = {
- version = State.version,
- }
- }
- return char
- end
- return stringChar(tonumber(char, 16))
- end,
- CharUtf8 = function (pos, char)
- if State.version ~= 'Lua 5.3'
- and State.version ~= 'Lua 5.4'
- and State.version ~= 'LuaJIT'
- then
- PushError {
- type = 'ERR_ESC',
- start = pos-3,
- finish = pos-2,
- version = {'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
- info = {
- version = State.version,
- }
- }
- return char
- end
- if #char == 0 then
- PushError {
- type = 'UTF8_SMALL',
- start = pos-3,
- finish = pos,
- }
- return ''
- end
- local v = tonumber(char, 16)
- if not v then
- for i = 1, #char do
- if not tonumber(char:sub(i, i), 16) then
- PushError {
- type = 'MUST_X16',
- start = pos + i - 1,
- finish = pos + i - 1,
- }
- end
- end
- return ''
- end
- if State.version == 'Lua 5.4' then
- if v < 0 or v > 0x7FFFFFFF then
- PushError {
- type = 'UTF8_MAX',
- start = pos-3,
- finish = pos+#char,
- info = {
- min = '00000000',
- max = '7FFFFFFF',
- }
- }
- end
- else
- if v < 0 or v > 0x10FFFF then
- PushError {
- type = 'UTF8_MAX',
- start = pos-3,
- finish = pos+#char,
- version = v <= 0x7FFFFFFF and 'Lua 5.4' or nil,
- info = {
- min = '000000',
- max = '10FFFF',
- }
- }
- end
- end
- if v >= 0 and v <= 0x10FFFF then
- return utf8Char(v)
- end
- return ''
- end,
- Number = function (start, number, finish)
- local n = tonumber(number)
- if n then
- State.LastNumber = {
- type = mathType(n) == 'integer' and 'integer' or 'number',
- start = start,
- finish = finish - 1,
- [1] = n,
- }
- State.LastRaw = number
- return State.LastNumber
- else
- PushError {
- type = 'MALFORMED_NUMBER',
- start = start,
- finish = finish - 1,
- }
- State.LastNumber = {
- type = 'number',
- start = start,
- finish = finish - 1,
- [1] = 0,
- }
- State.LastRaw = number
- return State.LastNumber
- end
- end,
- FFINumber = function (start, symbol)
- local lastNumber = State.LastNumber
- if State.LastRaw:find('.', 1, true) then
- PushError {
- type = 'UNKNOWN_SYMBOL',
- start = start,
- finish = start + #symbol - 1,
- info = {
- symbol = symbol,
- }
- }
- lastNumber[1] = 0
- return
- end
- if State.version ~= 'LuaJIT' then
- PushError {
- type = 'UNSUPPORT_SYMBOL',
- start = start,
- finish = start + #symbol - 1,
- version = 'LuaJIT',
- info = {
- version = State.version,
- }
- }
- lastNumber[1] = 0
- end
- end,
- ImaginaryNumber = function (start, symbol)
- local lastNumber = State.LastNumber
- if State.version ~= 'LuaJIT' then
- PushError {
- type = 'UNSUPPORT_SYMBOL',
- start = start,
- finish = start + #symbol - 1,
- version = 'LuaJIT',
- info = {
- version = State.version,
- }
- }
- end
- lastNumber[1] = 0
- end,
- Integer2 = function (start, word)
- if State.version ~= 'LuaJIT' then
- PushError {
- type = 'UNSUPPORT_SYMBOL',
- start = start,
- finish = start + 1,
- version = 'LuaJIT',
- info = {
- version = State.version,
- }
- }
- end
- local num = 0
- for i = 1, #word do
- if word:sub(i, i) == '1' then
- num = num | (1 << (i - 1))
- end
- end
- return tostring(num)
- end,
- Name = function (start, str, finish)
- local isKeyWord
- if RESERVED[str] then
- isKeyWord = true
- elseif str == 'goto' then
- if State.version ~= 'Lua 5.1' and State.version ~= 'LuaJIT' then
- isKeyWord = true
- end
- end
- if isKeyWord then
- PushError {
- type = 'KEYWORD',
- start = start,
- finish = finish - 1,
- }
- end
- if not State.options.unicodeName and str:find '[\x80-\xff]' then
- PushError {
- type = 'UNICODE_NAME',
- start = start,
- finish = finish - 1,
- }
- end
- return {
- type = 'name',
- start = start,
- finish = finish - 1,
- [1] = str,
- }
- end,
- GetField = function (dot, field)
- local obj = {
- type = 'getfield',
- field = field,
- dot = dot,
- start = dot.start,
- finish = (field or dot).finish,
- }
- if field then
- field.type = 'field'
- field.parent = obj
- end
- return obj
- end,
- GetIndex = function (start, index, finish)
- local obj = {
- type = 'getindex',
- bstart = start,
- start = start,
- finish = finish - 1,
- index = index,
- }
- if index then
- index.parent = obj
- end
- return obj
- end,
- GetMethod = function (colon, method)
- local obj = {
- type = 'getmethod',
- method = method,
- colon = colon,
- start = colon.start,
- finish = (method or colon).finish,
- }
- if method then
- method.type = 'method'
- method.parent = obj
- end
- return obj
- end,
- Single = function (unit)
- unit.type = 'getname'
- return unit
- end,
- Simple = function (units)
- local last = units[1]
- for i = 2, #units do
- local current = units[i]
- current.node = last
- current.start = last.start
- last.next = current
- last = units[i]
- end
- return last
- end,
- SimpleCall = function (call)
- if call.type ~= 'call' and call.type ~= 'getmethod' then
- PushError {
- type = 'EXP_IN_ACTION',
- start = call.start,
- finish = call.finish,
- }
- end
- return call
- end,
- BinaryOp = function (start, op)
- if SymbolAlias[op] then
- if State.options.nonstandardSymbol and State.options.nonstandardSymbol[op] then
- else
- PushError {
- type = 'ERR_NONSTANDARD_SYMBOL',
- start = start,
- finish = start + #op - 1,
- info = {
- symbol = SymbolAlias[op],
- },
- fix = {
- title = 'FIX_NONSTANDARD_SYMBOL',
- symbol = SymbolAlias[op],
- {
- start = start,
- finish = start + #op - 1,
- text = SymbolAlias[op],
- },
- }
- }
- end
- op = SymbolAlias[op]
- end
- return {
- type = op,
- start = start,
- finish = start + #op - 1,
- }
- end,
- UnaryOp = function (start, op)
- if SymbolAlias[op] then
- if State.options.nonstandardSymbol and State.options.nonstandardSymbol[op] then
- else
- PushError {
- type = 'ERR_NONSTANDARD_SYMBOL',
- start = start,
- finish = start + #op - 1,
- info = {
- symbol = SymbolAlias[op],
- },
- fix = {
- title = 'FIX_NONSTANDARD_SYMBOL',
- symbol = SymbolAlias[op],
- {
- start = start,
- finish = start + #op - 1,
- text = SymbolAlias[op],
- },
- }
- }
- end
- op = SymbolAlias[op]
- end
- return {
- type = op,
- start = start,
- finish = start + #op - 1,
- }
- end,
- Unary = function (first, ...)
- if not ... then
- return nil
- end
- local list = {first, ...}
- local e = list[#list]
- for i = #list - 1, 1, -1 do
- local op = list[i]
- checkOpVersion(op)
- e = {
- type = 'unary',
- op = op,
- start = op.start,
- finish = e.finish,
- [1] = e,
- }
- end
- return e
- end,
- SubBinary = function (op, symb)
- if symb then
- return op, symb
- end
- PushError {
- type = 'MISS_EXP',
- start = op.start,
- finish = op.finish,
- }
- end,
- Binary = function (first, op, second, ...)
- if not first then
- return second
- end
- if not op then
- return first
- end
- if not ... then
- checkOpVersion(op)
- return {
- type = 'binary',
- op = op,
- start = first.start,
- finish = second.finish,
- [1] = first,
- [2] = second,
- }
- end
- local list = {first, op, second, ...}
- local ops = {}
- for i = 2, #list, 2 do
- ops[#ops+1] = i
- end
- tableSort(ops, function (a, b)
- local op1 = list[a]
- local op2 = list[b]
- local lv1 = BinaryLevel[op1.type]
- local lv2 = BinaryLevel[op2.type]
- if lv1 == lv2 then
- local forward = BinaryForward[lv1]
- if forward then
- return op1.start > op2.start
- else
- return op1.start < op2.start
- end
- else
- return lv1 < lv2
- end
- end)
- local final
- for i = #ops, 1, -1 do
- local n = ops[i]
- local op = list[n]
- local left = list[n-1]
- local right = list[n+1]
- local exp = {
- type = 'binary',
- op = op,
- start = left.start,
- finish = right and right.finish or op.finish,
- [1] = left,
- [2] = right,
- }
- local leftIndex, rightIndex
- if list[left] then
- leftIndex = list[left[1]]
- else
- leftIndex = n - 1
- end
- if list[right] then
- rightIndex = list[right[2]]
- else
- rightIndex = n + 1
- end
-
- list[leftIndex] = exp
- list[rightIndex] = exp
- list[left] = leftIndex
- list[right] = rightIndex
- list[exp] = n
- final = exp
-
- checkOpVersion(op)
- end
- return final
- end,
- Paren = function (start, exp, finish)
- if exp and exp.type == 'paren' then
- exp.start = start
- exp.finish = finish - 1
- return exp
- end
- return {
- type = 'paren',
- start = start,
- finish = finish - 1,
- exp = exp
- }
- end,
- VarArgs = function (dots)
- dots.type = 'varargs'
- return dots
- end,
- PackLoopArgs = function (start, list, finish)
- local list = packList(start, list, finish)
- if #list == 0 then
- PushError {
- type = 'MISS_LOOP_MIN',
- start = finish,
- finish = finish,
- }
- elseif #list == 1 then
- PushError {
- type = 'MISS_LOOP_MAX',
- start = finish,
- finish = finish,
- }
- end
- return list
- end,
- PackInNameList = function (start, list, finish)
- local list = packList(start, list, finish)
- if #list == 0 then
- PushError {
- type = 'MISS_NAME',
- start = start,
- finish = finish,
- }
- end
- return list
- end,
- PackInExpList = function (start, list, finish)
- local list = packList(start, list, finish)
- if #list == 0 then
- PushError {
- type = 'MISS_EXP',
- start = start,
- finish = finish,
- }
- end
- return list
- end,
- PackExpList = function (start, list, finish)
- local list = packList(start, list, finish)
- return list
- end,
- PackNameList = function (start, list, finish)
- local list = packList(start, list, finish)
- return list
- end,
- Call = function (start, args, finish)
- return createCall(args, start, finish-1)
- end,
- COMMA = function (start)
- return {
- type = ',',
- start = start,
- finish = start,
- }
- end,
- SEMICOLON = function (start)
- return {
- type = ';',
- start = start,
- finish = start,
- }
- end,
- DOTS = function (start)
- return {
- type = '...',
- start = start,
- finish = start + 2,
- }
- end,
- COLON = function (start)
- return {
- type = ':',
- start = start,
- finish = start,
- }
- end,
- ASSIGN = function (start, symbol)
- if State.options.nonstandardSymbol and State.options.nonstandardSymbol[symbol] then
- else
- PushError {
- type = 'UNSUPPORT_SYMBOL',
- start = start,
- finish = start + #symbol - 1,
- info = {
- version = 'Lua',
- }
- }
- end
- end,
- DOT = function (start)
- return {
- type = '.',
- start = start,
- finish = start,
- }
- end,
- Function = function (functionStart, functionFinish, name, args, actions, endStart, endFinish)
- actions.type = 'function'
- actions.start = functionStart
- actions.finish = endFinish - 1
- actions.args = args
- actions.keyword= {
- functionStart, functionFinish - 1,
- endStart, endFinish - 1,
- }
- checkMissEnd(functionStart)
- if not name then
- return actions
- end
- if name.type == 'getname' then
- name.type = 'setname'
- name.value = actions
- elseif name.type == 'getfield' then
- name.type = 'setfield'
- name.value = actions
- elseif name.type == 'getmethod' then
- name.type = 'setmethod'
- name.value = actions
- elseif name.type == 'getindex' then
- name.type = 'setfield'
- name.value = actions
- PushError {
- type = 'INDEX_IN_FUNC_NAME',
- start = name.bstart,
- finish = name.finish,
- }
- end
- name.range = actions.finish
- name.vstart = functionStart
- return name
- end,
- LocalFunction = function (start, name)
- if name.type == 'function' then
- PushError {
- type = 'MISS_NAME',
- start = name.keyword[2] + 1,
- finish = name.keyword[2] + 1,
- }
- return name
- end
- if name.type ~= 'setname' then
- PushError {
- type = 'UNEXPECT_LFUNC_NAME',
- start = name.start,
- finish = name.finish,
- }
- return name
- end
-
- local loc = createLocal(name, name.start, name.value)
- loc.localfunction = true
- loc.vstart = name.value.start
- return name
- end,
- NamedFunction = function (name)
- if name.type == 'function' then
- PushError {
- type = 'MISS_NAME',
- start = name.keyword[2] + 1,
- finish = name.keyword[2] + 1,
- }
- end
- return name
- end,
- ExpFunction = function (func)
- if func.type ~= 'function' then
- PushError {
- type = 'UNEXPECT_EFUNC_NAME',
- start = func.start,
- finish = func.finish,
- }
- return func.value
- end
- return func
- end,
- Table = function (start, tbl, finish)
- tbl.type = 'table'
- tbl.start = start
- tbl.finish = finish - 1
- local wantField = true
- local lastStart = start + 1
- local fieldCount = 0
- local n = 0
- for i = 1, #tbl do
- local field = tbl[i]
- if field.type == ',' or field.type == ';' then
- if wantField then
- PushError {
- type = 'MISS_EXP',
- start = lastStart,
- finish = field.start - 1,
- }
- end
- wantField = true
- lastStart = field.finish + 1
- else
- if not wantField then
- PushError {
- type = 'MISS_SEP_IN_TABLE',
- start = lastStart,
- finish = field.start - 1,
- }
- end
- wantField = false
- lastStart = field.finish + 1
- fieldCount = fieldCount + 1
- tbl[fieldCount] = field
- if field.type == 'tableexp' then
- n = n + 1
- field.tindex = n
- end
- end
- end
- for i = fieldCount + 1, #tbl do
- tbl[i] = nil
- end
- return tbl
- end,
- NewField = function (start, field, value, finish)
- local obj = {
- type = 'tablefield',
- start = start,
- finish = finish-1,
- field = field,
- value = value,
- }
- if field then
- field.type = 'field'
- field.parent = obj
- end
- return obj
- end,
- NewIndex = function (start, index, value, finish)
- local obj = {
- type = 'tableindex',
- start = start,
- finish = finish-1,
- index = index,
- value = value,
- }
- if index then
- index.parent = obj
- end
- return obj
- end,
- TableExp = function (start, value, finish)
- if not value then
- return
- end
- local obj = {
- type = 'tableexp',
- start = start,
- finish = finish-1,
- value = value,
- }
- return obj
- end,
- FuncArgs = function (start, args, finish)
- args.type = 'funcargs'
- args.start = start
- args.finish = finish - 1
- local lastStart = start + 1
- local wantName = true
- local argCount = 0
- for i = 1, #args do
- local arg = args[i]
- local argAst = arg
- if argAst.type == ',' then
- if wantName then
- PushError {
- type = 'MISS_NAME',
- start = lastStart,
- finish = argAst.start-1,
- }
- end
- wantName = true
- else
- if not wantName then
- PushError {
- type = 'MISS_SYMBOL',
- start = lastStart-1,
- finish = argAst.start-1,
- info = {
- symbol = ',',
- }
- }
- end
- wantName = false
- argCount = argCount + 1
-
- if argAst.type == '...' then
- args[argCount] = arg
- if i < #args then
- local a = args[i+1]
- local b = args[#args]
- PushError {
- type = 'ARGS_AFTER_DOTS',
- start = a.start,
- finish = b.finish,
- }
- end
- break
- else
- args[argCount] = createLocal(arg, arg.start)
- end
- end
- lastStart = argAst.finish + 1
- end
- for i = argCount + 1, #args do
- args[i] = nil
- end
- if wantName and argCount > 0 then
- PushError {
- type = 'MISS_NAME',
- start = lastStart,
- finish = finish - 1,
- }
- end
- return args
- end,
- Set = function (start, keys, eqFinish, values, finish)
- for i = 1, #keys do
- local key = keys[i]
- if key.type == 'getname' then
- key.type = 'setname'
- key.value = getValue(values, i)
- elseif key.type == 'getfield' then
- key.type = 'setfield'
- key.value = getValue(values, i)
- elseif key.type == 'getindex' then
- key.type = 'setindex'
- key.value = getValue(values, i)
- else
- PushError {
- type = 'UNEXPECT_SYMBOL',
- start = eqFinish - 1,
- finish = eqFinish - 1,
- info = {
- symbol = '=',
- }
- }
- end
- if key.value then
- key.range = key.value.finish
- end
- end
- if values then
- for i = #keys+1, #values do
- local value = values[i]
- PushDiag('redundant-value', {
- start = value.start,
- finish = value.finish,
- max = #keys,
- passed = #values,
- })
- end
- end
- return tableUnpack(keys)
- end,
- LocalAttr = function (attrs)
- if #attrs == 0 then
- return nil
- end
- for i = 1, #attrs do
- local attr = attrs[i]
- local attrAst = attr
- attrAst.type = 'localattr'
- if State.version ~= 'Lua 5.4' then
- PushError {
- type = 'UNSUPPORT_SYMBOL',
- start = attrAst.start,
- finish = attrAst.finish,
- version = 'Lua 5.4',
- info = {
- version = State.version,
- }
- }
- elseif attrAst[1] ~= 'const' and attrAst[1] ~= 'close' then
- PushError {
- type = 'UNKNOWN_TAG',
- start = attrAst.start,
- finish = attrAst.finish,
- info = {
- tag = attrAst[1],
- }
- }
- elseif i > 1 then
- PushError {
- type = 'MULTI_TAG',
- start = attrAst.start,
- finish = attrAst.finish,
- info = {
- tag = attrAst[1],
- }
- }
- end
- end
- attrs.start = attrs[1].start
- attrs.finish = attrs[#attrs].finish
- return attrs
- end,
- LocalName = function (name, attrs)
- if not name then
- return
- end
- name.attrs = attrs
- return name
- end,
- Local = function (start, keys, values, finish)
- for i = 1, #keys do
- local key = keys[i]
- local attrs = key.attrs
- key.attrs = nil
- local value = getValue(values, i)
- createLocal(key, finish, value, attrs)
- end
- if values then
- for i = #keys+1, #values do
- local value = values[i]
- PushDiag('redundant-value', {
- start = value.start,
- finish = value.finish,
- max = #keys,
- passed = #values,
- })
- end
- end
- return tableUnpack(keys)
- end,
- Do = function (start, actions, endA, endB)
- actions.type = 'do'
- actions.start = start
- actions.finish = endB - 1
- actions.keyword= {
- start, start + #'do' - 1,
- endA , endB - 1,
- }
- checkMissEnd(start)
- return actions
- end,
- Break = function (start, finish)
- return {
- type = 'break',
- start = start,
- finish = finish - 1,
- }
- end,
- Return = function (start, exps, finish)
- exps.type = 'return'
- exps.start = start
- exps.finish = finish - 1
- return exps
- end,
- Label = function (start, name, finish)
- if State.version == 'Lua 5.1' then
- PushError {
- type = 'UNSUPPORT_SYMBOL',
- start = start,
- finish = finish - 1,
- version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
- info = {
- version = State.version,
- }
- }
- return
- end
- if not name then
- return
- end
- name.type = 'label'
- return name
- end,
- GoTo = function (start, name, finish)
- if State.version == 'Lua 5.1' then
- PushError {
- type = 'UNSUPPORT_SYMBOL',
- start = start,
- finish = finish - 1,
- version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
- info = {
- version = State.version,
- }
- }
- return
- end
- if not name then
- return
- end
- name.type = 'goto'
- return name
- end,
- IfBlock = function (ifStart, ifFinish, exp, thenStart, thenFinish, actions, finish)
- actions.type = 'ifblock'
- actions.start = ifStart
- actions.finish = finish - 1
- actions.filter = exp
- actions.keyword= {
- ifStart, ifFinish - 1,
- thenStart, thenFinish - 1,
- }
- return actions
- end,
- ElseIfBlock = function (elseifStart, elseifFinish, exp, thenStart, thenFinish, actions, finish)
- actions.type = 'elseifblock'
- actions.start = elseifStart
- actions.finish = finish - 1
- actions.filter = exp
- actions.keyword= {
- elseifStart, elseifFinish - 1,
- thenStart, thenFinish - 1,
- }
- return actions
- end,
- ElseBlock = function (elseStart, elseFinish, actions, finish)
- actions.type = 'elseblock'
- actions.start = elseStart
- actions.finish = finish - 1
- actions.keyword= {
- elseStart, elseFinish - 1,
- }
- return actions
- end,
- If = function (start, blocks, endStart, endFinish)
- blocks.type = 'if'
- blocks.start = start
- blocks.finish = endFinish - 1
- local hasElse
- for i = 1, #blocks do
- local block = blocks[i]
- if i == 1 and block.type ~= 'ifblock' then
- PushError {
- type = 'MISS_SYMBOL',
- start = block.start,
- finish = block.start,
- info = {
- symbol = 'if',
- }
- }
- end
- if hasElse then
- PushError {
- type = 'BLOCK_AFTER_ELSE',
- start = block.start,
- finish = block.finish,
- }
- end
- if block.type == 'elseblock' then
- hasElse = true
- end
- end
- checkMissEnd(start)
- return blocks
- end,
- Loop = function (forA, forB, arg, steps, doA, doB, blockStart, block, endA, endB)
- local loc = createLocal(arg, blockStart, steps[1])
- block.type = 'loop'
- block.start = forA
- block.finish = endB - 1
- block.loc = loc
- block.max = steps[2]
- block.step = steps[3]
- block.keyword= {
- forA, forB - 1,
- doA , doB - 1,
- endA, endB - 1,
- }
- checkMissEnd(forA)
- return block
- end,
- In = function (forA, forB, keys, inA, inB, exp, doA, doB, blockStart, block, endA, endB)
- local func = tableRemove(exp, 1)
- block.type = 'in'
- block.start = forA
- block.finish = endB - 1
- block.keys = keys
- block.keyword= {
- forA, forB - 1,
- inA , inB - 1,
- doA , doB - 1,
- endA, endB - 1,
- }
-
- local values
- if func then
- local call = createCall(exp, func.finish + 1, exp.finish)
- if #exp == 0 then
- exp[1] = getSelect(func, 2)
- exp[2] = getSelect(func, 3)
- exp[3] = getSelect(func, 4)
- end
- call.node = func
- call.start = inA
- call.finish = doB - 1
- func.next = call
- func.iterator = true
- values = { call }
- keys.range = call.finish
- end
- for i = 1, #keys do
- local loc = keys[i]
- if values then
- createLocal(loc, blockStart, getValue(values, i))
- else
- createLocal(loc, blockStart)
- end
- end
- checkMissEnd(forA)
- return block
- end,
- While = function (whileA, whileB, filter, doA, doB, block, endA, endB)
- block.type = 'while'
- block.start = whileA
- block.finish = endB - 1
- block.filter = filter
- block.keyword= {
- whileA, whileB - 1,
- doA , doB - 1,
- endA , endB - 1,
- }
- checkMissEnd(whileA)
- return block
- end,
- Repeat = function (repeatA, repeatB, block, untilA, untilB, filter, finish)
- block.type = 'repeat'
- block.start = repeatA
- block.finish = finish
- block.filter = filter
- block.keyword= {
- repeatA, repeatB - 1,
- untilA , untilB - 1,
- }
- return block
- end,
- RTContinue = function (_, pos, ...)
- if State.options.nonstandardSymbol and State.options.nonstandardSymbol['continue'] then
- return pos, ...
- else
- return false
- end
- end,
- Continue = function (start, finish)
- return {
- type = 'nonstandardSymbol.continue',
- start = start,
- finish = finish - 1,
- }
- end,
- Lua = function (start, actions, finish)
- actions.type = 'main'
- actions.start = start
- actions.finish = finish - 1
- return actions
- end,
-
- -- 捕获错误
- UnknownSymbol = function (start, symbol)
- PushError {
- type = 'UNKNOWN_SYMBOL',
- start = start,
- finish = start + #symbol - 1,
- info = {
- symbol = symbol,
- }
- }
- end,
- UnknownAction = function (start, symbol)
- PushError {
- type = 'UNKNOWN_SYMBOL',
- start = start,
- finish = start + #symbol - 1,
- info = {
- symbol = symbol,
- }
- }
- end,
- DirtyName = function (pos)
- PushError {
- type = 'MISS_NAME',
- start = pos,
- finish = pos,
- }
- return nil
- end,
- DirtyExp = function (pos)
- PushError {
- type = 'MISS_EXP',
- start = pos,
- finish = pos,
- }
- return nil
- end,
- MissExp = function (pos)
- PushError {
- type = 'MISS_EXP',
- start = pos,
- finish = pos,
- }
- end,
- MissExponent = function (start, finish)
- PushError {
- type = 'MISS_EXPONENT',
- start = start,
- finish = finish - 1,
- }
- end,
- MissQuote1 = function (pos)
- PushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = '"'
- }
- }
- end,
- MissQuote2 = function (pos)
- PushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = "'"
- }
- }
- end,
- MissQuote3 = function (pos)
- PushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = "`"
- }
- }
- end,
- MissEscX = function (pos)
- PushError {
- type = 'MISS_ESC_X',
- start = pos-2,
- finish = pos+1,
- }
- end,
- MissTL = function (pos)
- PushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = '{',
- }
- }
- end,
- MissTR = function (pos)
- PushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = '}',
- }
- }
- end,
- MissBR = function (pos)
- PushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = ']',
- }
- }
- end,
- MissPL = function (pos)
- PushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = '(',
- }
- }
- end,
- MissPR = function (pos)
- PushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = ')',
- }
- }
- end,
- ErrEsc = function (pos)
- PushError {
- type = 'ERR_ESC',
- start = pos-1,
- finish = pos,
- }
- end,
- MustX16 = function (pos, str)
- PushError {
- type = 'MUST_X16',
- start = pos,
- finish = pos + #str - 1,
- }
- end,
- MissAssign = function (pos)
- PushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = '=',
- }
- }
- end,
- MissTableSep = function (pos)
- PushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = ','
- }
- }
- end,
- MissField = function (pos)
- PushError {
- type = 'MISS_FIELD',
- start = pos,
- finish = pos,
- }
- end,
- MissMethod = function (pos)
- PushError {
- type = 'MISS_METHOD',
- start = pos,
- finish = pos,
- }
- end,
- MissLabel = function (pos)
- PushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = '::',
- }
- }
- end,
- MissEnd = function (pos)
- State.MissEndErr = PushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = 'end',
- }
- }
- return pos, pos
- end,
- MissDo = function (pos)
- PushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = 'do',
- }
- }
- return pos, pos
- end,
- MissComma = function (pos)
- PushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = ',',
- }
- }
- end,
- MissIn = function (pos)
- PushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = 'in',
- }
- }
- return pos, pos
- end,
- MissUntil = function (pos)
- PushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = 'until',
- }
- }
- return pos, pos
- end,
- MissThen = function (pos)
- PushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = 'then',
- }
- }
- return pos, pos
- end,
- MissName = function (pos)
- PushError {
- type = 'MISS_NAME',
- start = pos,
- finish = pos,
- }
- end,
- ExpInAction = function (start, exp, finish)
- PushError {
- type = 'EXP_IN_ACTION',
- start = start,
- finish = finish - 1,
- }
- -- 当exp为nil时,不能返回任何值,否则会产生带洞的actionlist
- if exp then
- return exp
- else
- return
- end
- end,
- MissIf = function (start, block)
- PushError {
- type = 'MISS_SYMBOL',
- start = start,
- finish = start,
- info = {
- symbol = 'if',
- }
- }
- return block
- end,
- MissGT = function (start)
- PushError {
- type = 'MISS_SYMBOL',
- start = start,
- finish = start,
- info = {
- symbol = '>'
- }
- }
- end,
- ErrAssign = function (start, finish)
- PushError {
- type = 'ERR_ASSIGN_AS_EQ',
- start = start,
- finish = finish - 1,
- fix = {
- title = 'FIX_ASSIGN_AS_EQ',
- {
- start = start,
- finish = finish - 1,
- text = '=',
- }
- }
- }
- end,
- ErrEQ = function (start, finish)
- PushError {
- type = 'ERR_EQ_AS_ASSIGN',
- start = start,
- finish = finish - 1,
- fix = {
- title = 'FIX_EQ_AS_ASSIGN',
- {
- start = start,
- finish = finish - 1,
- text = '==',
- }
- }
- }
- return '=='
- end,
- ErrUEQ = function (start, finish)
- PushError {
- type = 'ERR_UEQ',
- start = start,
- finish = finish - 1,
- fix = {
- title = 'FIX_UEQ',
- {
- start = start,
- finish = finish - 1,
- text = '~=',
- }
- }
- }
- return '=='
- end,
- ErrThen = function (start, finish)
- PushError {
- type = 'ERR_THEN_AS_DO',
- start = start,
- finish = finish - 1,
- fix = {
- title = 'FIX_THEN_AS_DO',
- {
- start = start,
- finish = finish - 1,
- text = 'then',
- }
- }
- }
- return start, finish
- end,
- ErrDo = function (start, finish)
- PushError {
- type = 'ERR_DO_AS_THEN',
- start = start,
- finish = finish - 1,
- fix = {
- title = 'FIX_DO_AS_THEN',
- {
- start = start,
- finish = finish - 1,
- text = 'do',
- }
- }
- }
- return start, finish
- end,
- MissSpaceBetween = function (start)
- PushError {
- type = 'MISS_SPACE_BETWEEN',
- start = start,
- finish = start + 1,
- fix = {
- title = 'FIX_INSERT_SPACE',
- {
- start = start + 1,
- finish = start,
- text = ' ',
- }
- }
- }
- end,
- CallArgSnip = function (name, tailStart, tailSymbol)
- PushError {
- type = 'UNEXPECT_SYMBOL',
- start = tailStart,
- finish = tailStart,
- info = {
- symbol = tailSymbol,
- }
- }
- return name
- end
-}
-
-local function init(state)
- State = state
- PushError = state.pushError
- PushDiag = state.pushDiag
- PushComment = state.pushComment
-end
-
-local function close()
- State = DefaultState
- PushError = function (...) end
- PushDiag = function (...) end
- PushComment = function (...) end
-end
-
-return {
- defs = Defs,
- init = init,
- close = close,
-}
diff --git a/script/parser/calcline.lua b/script/parser/calcline.lua
deleted file mode 100644
index 2e944167..00000000
--- a/script/parser/calcline.lua
+++ /dev/null
@@ -1,94 +0,0 @@
-local m = require 'lpeglabel'
-local util = require 'utility'
-
-local row
-local fl
-local NL = (m.P'\r\n' + m.S'\r\n') * m.Cp() / function (pos)
- row = row + 1
- fl = pos
-end
-local ROWCOL = (NL + m.P(1))^0
-local function rowcol(str, n)
- row = 1
- fl = 1
- ROWCOL:match(str:sub(1, n))
- local col = n - fl + 1
- return row, col
-end
-
-local function rowcol_utf8(str, n)
- row = 1
- fl = 1
- ROWCOL:match(str:sub(1, n))
- return row, util.utf8Len(str, fl, n)
-end
-
-local function position(str, _row, _col)
- local cur = 1
- local row = 1
- while true do
- if row == _row then
- return cur + _col - 1
- elseif row > _row then
- return cur - 1
- end
- local pos = str:find('[\r\n]', cur)
- if not pos then
- return #str
- end
- row = row + 1
- if str:sub(pos, pos+1) == '\r\n' then
- cur = pos + 2
- else
- cur = pos + 1
- end
- end
-end
-
-local function position_utf8(str, _row, _col)
- local cur = 1
- local row = 1
- while true do
- if row == _row then
- return utf8.offset(str, _col, cur)
- elseif row > _row then
- return cur - 1
- end
- local pos = str:find('[\r\n]', cur)
- if not pos then
- return #str
- end
- row = row + 1
- if str:sub(pos, pos+1) == '\r\n' then
- cur = pos + 2
- else
- cur = pos + 1
- end
- end
-end
-
-local NL = m.P'\r\n' + m.S'\r\n'
-
-local function line(str, row)
- local count = 0
- local res
- local LINE = m.Cmt((1 - NL)^0, function (_, _, c)
- count = count + 1
- if count == row then
- res = c
- return false
- end
- return true
- end)
- local MATCH = (LINE * NL)^0 * LINE
- MATCH:match(str)
- return res
-end
-
-return {
- rowcol = rowcol,
- rowcol_utf8 = rowcol_utf8,
- position = position,
- position_utf8 = position_utf8,
- line = line,
-}
diff --git a/script/parser/compile.lua b/script/parser/compile.lua
index 752728d1..915a2764 100644
--- a/script/parser/compile.lua
+++ b/script/parser/compile.lua
@@ -1,12 +1,109 @@
-local guide = require 'parser.guide'
-local parse = require 'parser.parse'
-local newparser = require 'parser.newparser'
-local type = type
-local tableInsert = table.insert
-local pairs = pairs
-local os = os
-
-local specials = {
+local tokens = require 'parser.tokens'
+local guide = require 'parser.guide'
+
+local sbyte = string.byte
+local sfind = string.find
+local smatch = string.match
+local sgsub = string.gsub
+local ssub = string.sub
+local schar = string.char
+local supper = string.upper
+local uchar = utf8.char
+local tconcat = table.concat
+local tinsert = table.insert
+local tointeger = math.tointeger
+local tonumber = tonumber
+local maxinteger = math.maxinteger
+local assert = assert
+
+_ENV = nil
+
+---@alias parser.position integer
+
+---@param str string
+---@return table<integer, boolean>
+local function stringToCharMap(str)
+ local map = {}
+ local pos = 1
+ while pos <= #str do
+ local byte = sbyte(str, pos, pos)
+ map[schar(byte)] = true
+ pos = pos + 1
+ if ssub(str, pos, pos) == '-'
+ and pos < #str then
+ pos = pos + 1
+ local byte2 = sbyte(str, pos, pos)
+ assert(byte < byte2)
+ for b = byte + 1, byte2 do
+ map[schar(b)] = true
+ end
+ pos = pos + 1
+ end
+ end
+ return map
+end
+
+local CharMapNumber = stringToCharMap '0-9'
+local CharMapN16 = stringToCharMap 'xX'
+local CharMapN2 = stringToCharMap 'bB'
+local CharMapE10 = stringToCharMap 'eE'
+local CharMapE16 = stringToCharMap 'pP'
+local CharMapSign = stringToCharMap '+-'
+local CharMapSB = stringToCharMap 'ao|~&=<>.*/%^+-'
+local CharMapSU = stringToCharMap 'n#~!-'
+local CharMapSimple = stringToCharMap '.:([\'"{'
+local CharMapStrSH = stringToCharMap '\'"`'
+local CharMapStrLH = stringToCharMap '['
+local CharMapTSep = stringToCharMap ',;'
+local CharMapWord = stringToCharMap '_a-zA-Z\x80-\xff'
+
+local EscMap = {
+ ['a'] = '\a',
+ ['b'] = '\b',
+ ['f'] = '\f',
+ ['n'] = '\n',
+ ['r'] = '\r',
+ ['t'] = '\t',
+ ['v'] = '\v',
+ ['\\'] = '\\',
+ ['\''] = '\'',
+ ['\"'] = '\"',
+}
+
+local NLMap = {
+ ['\n'] = true,
+ ['\r'] = true,
+ ['\r\n'] = true,
+}
+
+local LineMulti = 10000
+
+-- goto 单独处理
+local KeyWord = {
+ ['and'] = true,
+ ['break'] = true,
+ ['do'] = true,
+ ['else'] = true,
+ ['elseif'] = true,
+ ['end'] = true,
+ ['false'] = true,
+ ['for'] = true,
+ ['function'] = true,
+ ['if'] = true,
+ ['in'] = true,
+ ['local'] = true,
+ ['nil'] = true,
+ ['not'] = true,
+ ['or'] = true,
+ ['repeat'] = true,
+ ['return'] = true,
+ ['then'] = true,
+ ['true'] = true,
+ ['until'] = true,
+ ['while'] = true,
+}
+
+local Specials = {
['_G'] = true,
['rawset'] = true,
['rawget'] = true,
@@ -18,491 +115,622 @@ local specials = {
['xpcall'] = true,
['pairs'] = true,
['ipairs'] = true,
+ ['assert'] = true,
+ ['error'] = true,
+ ['type'] = true,
}
-_ENV = nil
+local UnarySymbol = {
+ ['not'] = 11,
+ ['#'] = 11,
+ ['~'] = 11,
+ ['-'] = 11,
+}
+
+local BinarySymbol = {
+ ['or'] = 1,
+ ['and'] = 2,
+ ['<='] = 3,
+ ['>='] = 3,
+ ['<'] = 3,
+ ['>'] = 3,
+ ['~='] = 3,
+ ['=='] = 3,
+ ['|'] = 4,
+ ['~'] = 5,
+ ['&'] = 6,
+ ['<<'] = 7,
+ ['>>'] = 7,
+ ['..'] = 8,
+ ['+'] = 9,
+ ['-'] = 9,
+ ['*'] = 10,
+ ['//'] = 10,
+ ['/'] = 10,
+ ['%'] = 10,
+ ['^'] = 12,
+}
+
+local BinaryAlias = {
+ ['&&'] = 'and',
+ ['||'] = 'or',
+ ['!='] = '~=',
+}
+
+local BinaryActionAlias = {
+ ['='] = '==',
+}
+
+local UnaryAlias = {
+ ['!'] = 'not',
+}
+
+local SymbolForward = {
+ [01] = true,
+ [02] = true,
+ [03] = true,
+ [04] = true,
+ [05] = true,
+ [06] = true,
+ [07] = true,
+ [08] = false,
+ [09] = true,
+ [10] = true,
+ [11] = true,
+ [12] = false,
+}
+
+local GetToSetMap = {
+ ['getglobal'] = 'setglobal',
+ ['getlocal'] = 'setlocal',
+ ['getfield'] = 'setfield',
+ ['getindex'] = 'setindex',
+ ['getmethod'] = 'setmethod',
+}
+
+local ChunkFinishMap = {
+ ['end'] = true,
+ ['else'] = true,
+ ['elseif'] = true,
+ ['in'] = true,
+ ['then'] = true,
+ ['until'] = true,
+ [';'] = true,
+ [']'] = true,
+ [')'] = true,
+ ['}'] = true,
+}
+
+local ChunkStartMap = {
+ ['do'] = true,
+ ['else'] = true,
+ ['elseif'] = true,
+ ['for'] = true,
+ ['function'] = true,
+ ['if'] = true,
+ ['local'] = true,
+ ['repeat'] = true,
+ ['return'] = true,
+ ['then'] = true,
+ ['until'] = true,
+ ['while'] = true,
+}
+
+local ListFinishMap = {
+ ['end'] = true,
+ ['else'] = true,
+ ['elseif'] = true,
+ ['in'] = true,
+ ['then'] = true,
+ ['do'] = true,
+ ['until'] = true,
+ ['for'] = true,
+ ['if'] = true,
+ ['local'] = true,
+ ['repeat'] = true,
+ ['return'] = true,
+ ['while'] = true,
+}
+
+local State, Lua, Line, LineOffset, Chunk, Tokens, Index, LastTokenFinish, Mode, LocalCount
local LocalLimit = 200
-local pushError, Compile, CompileBlock, Block, GoToTag, ENVMode, Compiled, LocalCount, Version, Root, Options
-local function addRef(node, obj)
- if not node.ref then
- node.ref = {}
- end
- node.ref[#node.ref+1] = obj
- obj.node = node
-end
+local parseExp, parseAction
+
+local pushError
local function addSpecial(name, obj)
- if not Root.specials then
- Root.specials = {}
+ if not State.specials then
+ State.specials = {}
end
- if not Root.specials[name] then
- Root.specials[name] = {}
+ if not State.specials[name] then
+ State.specials[name] = {}
end
- Root.specials[name][#Root.specials[name]+1] = obj
+ State.specials[name][#State.specials[name]+1] = obj
obj.special = name
end
-local vmMap = {
- ['getname'] = function (obj)
- local loc = guide.getLocal(obj, obj[1], obj.start)
- if loc then
- obj.type = 'getlocal'
- obj.loc = loc
- addRef(loc, obj)
- if loc.special then
- addSpecial(loc.special, obj)
- end
- else
- obj.type = 'getglobal'
- local node = guide.getLocal(obj, ENVMode, obj.start)
- if node then
- addRef(node, obj)
- end
- local name = obj[1]
- if specials[name] then
- addSpecial(name, obj)
- elseif Options and Options.special then
- local asName = Options.special[name]
- if specials[asName] then
- addSpecial(asName, obj)
- end
- end
+---@param offset integer
+---@param leftOrRight '"left"'|'"right"'
+local function getPosition(offset, leftOrRight)
+ if not offset or offset > #Lua then
+ return LineMulti * Line + #Lua - LineOffset + 1
+ end
+ if leftOrRight == 'left' then
+ return LineMulti * Line + offset - LineOffset
+ else
+ return LineMulti * Line + offset - LineOffset + 1
+ end
+end
+
+---@return string? word
+---@return parser.position? startPosition
+---@return parser.position? finishPosition
+local function peekWord()
+ local word = Tokens[Index + 1]
+ if not word then
+ return nil
+ end
+ if not CharMapWord[ssub(word, 1, 1)] then
+ return nil
+ end
+ local startPos = getPosition(Tokens[Index] , 'left')
+ local finishPos = getPosition(Tokens[Index] + #word - 1, 'right')
+ return word, startPos, finishPos
+end
+
+local function lastRightPosition()
+ if Index < 2 then
+ return 0
+ end
+ local token = Tokens[Index - 1]
+ if NLMap[token] then
+ return LastTokenFinish
+ elseif token then
+ return getPosition(Tokens[Index - 2] + #token - 1, 'right')
+ else
+ return getPosition(#Lua, 'right')
+ end
+end
+
+local function missSymbol(symbol, start, finish)
+ pushError {
+ type = 'MISS_SYMBOL',
+ start = start or lastRightPosition(),
+ finish = finish or start or lastRightPosition(),
+ info = {
+ symbol = symbol,
+ }
+ }
+end
+
+local function missExp()
+ pushError {
+ type = 'MISS_EXP',
+ start = lastRightPosition(),
+ finish = lastRightPosition(),
+ }
+end
+
+local function missName(pos)
+ pushError {
+ type = 'MISS_NAME',
+ start = pos or lastRightPosition(),
+ finish = pos or lastRightPosition(),
+ }
+end
+
+local function missEnd(relatedStart, relatedFinish)
+ pushError {
+ type = 'MISS_SYMBOL',
+ start = lastRightPosition(),
+ finish = lastRightPosition(),
+ info = {
+ symbol = 'end',
+ related = {
+ {
+ start = relatedStart,
+ finish = relatedFinish,
+ }
+ }
+ }
+ }
+ pushError {
+ type = 'MISS_END',
+ start = relatedStart,
+ finish = relatedFinish,
+ }
+end
+
+local function unknownSymbol(start, finish, word)
+ local token = word or Tokens[Index + 1]
+ if not token then
+ return false
+ end
+ pushError {
+ type = 'UNKNOWN_SYMBOL',
+ start = start or getPosition(Tokens[Index], 'left'),
+ finish = finish or getPosition(Tokens[Index] + #token - 1, 'right'),
+ info = {
+ symbol = token,
+ }
+ }
+ return true
+end
+
+local function skipUnknownSymbol(stopSymbol)
+ if unknownSymbol() then
+ Index = Index + 2
+ return true
+ end
+ return false
+end
+
+local function skipNL()
+ local token = Tokens[Index + 1]
+ if NLMap[token] then
+ if Index >= 2 and not NLMap[Tokens[Index - 1]] then
+ LastTokenFinish = getPosition(Tokens[Index - 2] + #Tokens[Index - 1] - 1, 'right')
end
- return obj
- end,
- ['getfield'] = function (obj)
- Compile(obj.node, obj)
- end,
- ['call'] = function (obj)
- Compile(obj.node, obj)
- if obj.node and obj.node.type == 'getmethod' then
- if not obj.args then
- obj.args = {
- type = 'callargs',
- start = obj.start,
- finish = obj.finish,
- parent = obj,
- }
- end
- local newNode = {}
- for k, v in pairs(obj.node.node) do
- newNode[k] = v
- end
- newNode.mirror = obj.node.node
- newNode.dummy = true
- newNode.parent = obj.args
- obj.node.node.mirror = newNode
- tableInsert(obj.args, 1, newNode)
- Compiled[newNode] = true
- end
- Compile(obj.args, obj)
- end,
- ['callargs'] = function (obj)
- for i = 1, #obj do
- Compile(obj[i], obj)
- end
- end,
- ['binary'] = function (obj)
- Compile(obj[1], obj)
- Compile(obj[2], obj)
- end,
- ['unary'] = function (obj)
- Compile(obj[1], obj)
- end,
- ['varargs'] = function (obj)
- local func = guide.getParentFunction(obj)
- if func then
- local index, vararg = guide.getFunctionVarArgs(func)
- if not index then
- pushError {
- type = 'UNEXPECT_DOTS',
- start = obj.start,
- finish = obj.finish,
+ Line = Line + 1
+ LineOffset = Tokens[Index] + #token
+ Index = Index + 2
+ State.lines[Line] = LineOffset
+ return true
+ end
+ return false
+end
+
+local function getSavePoint()
+ local index = Index
+ local line = Line
+ local lineOffset = LineOffset
+ local errs = State.errs
+ local errCount = #errs
+ return function ()
+ Index = index
+ Line = line
+ LineOffset = lineOffset
+ for i = errCount + 1, #errs do
+ errs[i] = nil
+ end
+ end
+end
+
+local function fastForwardToken(offset)
+ while true do
+ local myOffset = Tokens[Index]
+ if not myOffset
+ or myOffset >= offset then
+ break
+ end
+ local token = Tokens[Index + 1]
+ if NLMap[token] then
+ Line = Line + 1
+ LineOffset = Tokens[Index] + #token
+ State.lines[Line] = LineOffset
+ end
+ Index = Index + 2
+ end
+end
+
+local function resolveLongString(finishMark)
+ skipNL()
+ local miss
+ local start = Tokens[Index]
+ local finishOffset = sfind(Lua, finishMark, start, true)
+ if not finishOffset then
+ finishOffset = #Lua + 1
+ miss = true
+ end
+ local stringResult = start and ssub(Lua, start, finishOffset - 1) or ''
+ local lastLN = stringResult:find '[\r\n][^\r\n]*$'
+ if lastLN then
+ local result = stringResult
+ : gsub('\r\n?', '\n')
+ stringResult = result
+ end
+ fastForwardToken(finishOffset + #finishMark)
+ if miss then
+ local pos = getPosition(finishOffset - 1, 'right')
+ pushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = finishMark,
+ },
+ fix = {
+ title = 'ADD_LSTRING_END',
+ {
+ start = pos,
+ finish = pos,
+ text = finishMark,
}
- end
- if vararg then
- if not vararg.ref then
- vararg.ref = {}
- end
- vararg.ref[#vararg.ref+1] = obj
- obj.node = vararg
- end
- end
- end,
- ['paren'] = function (obj)
- Compile(obj.exp, obj)
- end,
- ['getindex'] = function (obj)
- Compile(obj.node, obj)
- Compile(obj.index, obj)
- end,
- ['setindex'] = function (obj)
- Compile(obj.node, obj)
- Compile(obj.index, obj)
- Compile(obj.value, obj)
- end,
- ['getmethod'] = function (obj)
- Compile(obj.node, obj)
- Compile(obj.method, obj)
- end,
- ['setmethod'] = function (obj)
- Compile(obj.node, obj)
- Compile(obj.method, obj)
- local value = obj.value
- local localself = {
- type = 'local',
- start = value.start,
- finish = value.start,
- method = obj,
- effect = obj.finish,
- tag = 'self',
- dummy = true,
- [1] = 'self',
+ },
}
- if not value.args then
- value.args = {
- type = 'funcargs',
- start = obj.start,
- finish = obj.finish,
- }
+ end
+ return stringResult, getPosition(finishOffset + #finishMark - 1, 'right')
+end
+
+local function parseLongString()
+ local start, finish, mark = sfind(Lua, '^(%[%=*%[)', Tokens[Index])
+ if not mark then
+ return nil
+ end
+ fastForwardToken(finish + 1)
+ local startPos = getPosition(start, 'left')
+ local finishMark = sgsub(mark, '%[', ']')
+ local stringResult, finishPos = resolveLongString(finishMark)
+ return {
+ type = 'string',
+ start = startPos,
+ finish = finishPos,
+ [1] = stringResult,
+ [2] = mark,
+ }
+end
+
+local function pushCommentHeadError(left)
+ if State.options.nonstandardSymbol['//'] then
+ return
+ end
+ pushError {
+ type = 'ERR_COMMENT_PREFIX',
+ start = left,
+ finish = left + 2,
+ fix = {
+ title = 'FIX_COMMENT_PREFIX',
+ {
+ start = left,
+ finish = left + 2,
+ text = '--',
+ },
+ }
+ }
+end
+
+local function pushLongCommentError(left, right)
+ if State.options.nonstandardSymbol['/**/'] then
+ return
+ end
+ pushError {
+ type = 'ERR_C_LONG_COMMENT',
+ start = left,
+ finish = right,
+ fix = {
+ title = 'FIX_C_LONG_COMMENT',
+ {
+ start = left,
+ finish = left + 2,
+ text = '--[[',
+ },
+ {
+ start = right - 2,
+ finish = right,
+ text = '--]]'
+ },
+ }
+ }
+end
+
+local function skipComment(isAction)
+ local token = Tokens[Index + 1]
+ if token == '--'
+ or (
+ token == '//'
+ and (
+ isAction
+ or State.options.nonstandardSymbol['//']
+ )
+ ) then
+ local start = Tokens[Index]
+ local left = getPosition(start, 'left')
+ local chead = false
+ if token == '//' then
+ chead = true
+ pushCommentHeadError(left)
end
- tableInsert(value.args, 1, localself)
- Compile(value, obj)
- end,
- ['function'] = function (obj)
- local lastBlock = Block
- local LastLocalCount = LocalCount
- Block = obj
- LocalCount = 0
- Compile(obj.args, obj)
- for i = 1, #obj do
- Compile(obj[i], obj)
- end
- Block = lastBlock
- LocalCount = LastLocalCount
- end,
- ['funcargs'] = function (obj)
- for i = 1, #obj do
- Compile(obj[i], obj)
- end
- end,
- ['table'] = function (obj)
- for i = 1, #obj do
- Compile(obj[i], obj)
- end
- end,
- ['tablefield'] = function (obj)
- Compile(obj.value, obj)
- end,
- ['tableindex'] = function (obj)
- Compile(obj.index, obj)
- Compile(obj.value, obj)
- end,
- ['tableexp'] = function (obj)
- Compile(obj.value, obj)
- end,
- ['index'] = function (obj)
- Compile(obj.index, obj)
- end,
- ['select'] = function (obj)
- local vararg = obj.vararg
- if vararg.parent then
- if not vararg.extParent then
- vararg.extParent = {}
- end
- vararg.extParent[#vararg.extParent+1] = obj
- else
- Compile(vararg, obj)
- end
- end,
- ['setname'] = function (obj)
- Compile(obj.value, obj)
- local loc = guide.getLocal(obj, obj[1], obj.start)
- if loc then
- obj.type = 'setlocal'
- obj.loc = loc
- addRef(loc, obj)
- if loc.attrs then
- local const
- for i = 1, #loc.attrs do
- local attr = loc.attrs[i][1]
- if attr == 'const'
- or attr == 'close' then
- const = true
- break
- end
- end
- if const then
- pushError {
- type = 'SET_CONST',
- start = obj.start,
- finish = obj.finish,
- }
- end
- end
- else
- obj.type = 'setglobal'
- local node = guide.getLocal(obj, ENVMode, obj.start)
- if node then
- addRef(node, obj)
- end
- local name = obj[1]
- if specials[name] then
- addSpecial(name, obj)
- elseif Options and Options.special then
- local asName = Options.special[name]
- if specials[asName] then
- addSpecial(asName, obj)
- end
- end
+ Index = Index + 2
+ local longComment = start + 2 == Tokens[Index] and parseLongString()
+ if longComment then
+ longComment.type = 'comment.long'
+ longComment.text = longComment[1]
+ longComment.mark = longComment[2]
+ longComment[1] = nil
+ longComment[2] = nil
+ State.comms[#State.comms+1] = longComment
+ return true
end
- end,
- ['local'] = function (obj)
- local attrs = obj.attrs
- if attrs then
- for i = 1, #attrs do
- Compile(attrs[i], obj)
+ while true do
+ local nl = Tokens[Index + 1]
+ if not nl or NLMap[nl] then
+ break
end
+ Index = Index + 2
end
- if Block then
- if not Block.locals then
- Block.locals = {}
- end
- Block.locals[#Block.locals+1] = obj
- LocalCount = LocalCount + 1
- if LocalCount > LocalLimit then
- pushError {
- type = 'LOCAL_LIMIT',
- start = obj.start,
- finish = obj.finish,
+ local right = Tokens[Index] and (Tokens[Index] - 1) or #Lua
+ State.comms[#State.comms+1] = {
+ type = chead and 'comment.cshort' or 'comment.short',
+ start = left,
+ finish = getPosition(right, 'right'),
+ text = ssub(Lua, start + 2, right),
+ }
+ return true
+ end
+ if token == '/*' then
+ local start = Tokens[Index]
+ local left = getPosition(start, 'left')
+ Index = Index + 2
+ local result, right = resolveLongString '*/'
+ pushLongCommentError(left, right)
+ State.comms[#State.comms+1] = {
+ type = 'comment.long',
+ start = left,
+ finish = right,
+ text = result,
+ }
+ return true
+ end
+ return false
+end
+
+local function skipSpace(isAction)
+ repeat until not skipNL()
+ and not skipComment(isAction)
+end
+
+local function expectAssign(isAction)
+ local token = Tokens[Index + 1]
+ if token == '=' then
+ Index = Index + 2
+ return true
+ end
+ if token == '==' then
+ local left = getPosition(Tokens[Index], 'left')
+ local right = getPosition(Tokens[Index] + #token - 1, 'right')
+ pushError {
+ type = 'ERR_ASSIGN_AS_EQ',
+ start = left,
+ finish = right,
+ fix = {
+ title = 'FIX_ASSIGN_AS_EQ',
+ {
+ start = left,
+ finish = right,
+ text = '=',
}
+ }
+ }
+ Index = Index + 2
+ return true
+ end
+ if isAction then
+ if token == '+='
+ or token == '-='
+ or token == '*='
+ or token == '/='
+ or token == '%='
+ or token == '^='
+ or token == '//='
+ or token == '|='
+ or token == '&='
+ or token == '>>='
+ or token == '<<=' then
+ if not State.options.nonstandardSymbol[token] then
+ unknownSymbol()
end
+ Index = Index + 2
+ return true
end
- if obj.localfunction then
- obj.localfunction = nil
+ end
+ return false
+end
+
+local function parseLocalAttrs()
+ local attrs
+ while true do
+ skipSpace()
+ local token = Tokens[Index + 1]
+ if token ~= '<' then
+ break
end
- Compile(obj.value, obj)
- if obj.value and obj.value.special then
- addSpecial(obj.value.special, obj)
+ if not attrs then
+ attrs = {
+ type = 'localattrs',
+ }
end
- end,
- ['setfield'] = function (obj)
- Compile(obj.node, obj)
- Compile(obj.value, obj)
- end,
- ['do'] = function (obj)
- local lastBlock = Block
- Block = obj
- CompileBlock(obj, obj)
- if Block.locals then
- LocalCount = LocalCount - #Block.locals
+ local attr = {
+ type = 'localattr',
+ parent = attrs,
+ start = getPosition(Tokens[Index], 'left'),
+ finish = getPosition(Tokens[Index], 'right'),
+ }
+ attrs[#attrs+1] = attr
+ Index = Index + 2
+ skipSpace()
+ local word, wstart, wfinish = peekWord()
+ if word then
+ attr[1] = word
+ attr.finish = wfinish
+ Index = Index + 2
+ if word ~= 'const'
+ and word ~= 'close' then
+ pushError {
+ type = 'UNKNOWN_ATTRIBUTE',
+ start = wstart,
+ finish = wfinish,
+ }
+ end
+ else
+ missName()
end
- Block = lastBlock
- end,
- ['return'] = function (obj)
- for i = 1, #obj do
- Compile(obj[i], obj)
+ attr.finish = lastRightPosition()
+ skipSpace()
+ if Tokens[Index + 1] == '>' then
+ attr.finish = getPosition(Tokens[Index], 'right')
+ Index = Index + 2
+ elseif Tokens[Index + 1] == '>=' then
+ attr.finish = getPosition(Tokens[Index], 'right')
+ pushError {
+ type = 'MISS_SPACE_BETWEEN',
+ start = getPosition(Tokens[Index], 'left'),
+ finish = getPosition(Tokens[Index] + 1, 'right'),
+ }
+ Index = Index + 2
+ else
+ missSymbol '>'
end
- if Block and Block[#Block] ~= obj then
+ if State.version ~= 'Lua 5.4' then
pushError {
- type = 'ACTION_AFTER_RETURN',
- start = obj.start,
- finish = obj.finish,
+ type = 'UNSUPPORT_SYMBOL',
+ start = attr.start,
+ finish = attr.finish,
+ version = 'Lua 5.4',
+ info = {
+ version = State.version
+ }
}
end
- local func = guide.getParentFunction(obj)
- if func then
- if not func.returns then
- func.returns = {}
- end
- func.returns[#func.returns+1] = obj
+ end
+ return attrs
+end
+
+local function createLocal(obj, attrs)
+ obj.type = 'local'
+ obj.effect = obj.finish
+
+ if attrs then
+ obj.attrs = attrs
+ attrs.parent = obj
+ end
+
+ local chunk = Chunk[#Chunk]
+ if chunk then
+ local locals = chunk.locals
+ if not locals then
+ locals = {}
+ chunk.locals = locals
end
- end,
- ['label'] = function (obj)
- local block = guide.getBlock(obj)
- if block then
- if not block.labels then
- block.labels = {}
- end
- local name = obj[1]
- local label = guide.getLabel(block, name)
- if label then
- if Version == 'Lua 5.4'
- or block == guide.getBlock(label) then
- pushError {
- type = 'REDEFINED_LABEL',
- start = obj.start,
- finish = obj.finish,
- relative = {
- {
- label.start,
- label.finish,
- }
- }
- }
- end
- end
- block.labels[name] = obj
- end
- end,
- ['goto'] = function (obj)
- GoToTag[#GoToTag+1] = obj
- end,
- ['if'] = function (obj)
- for i = 1, #obj do
- Compile(obj[i], obj)
- end
- end,
- ['ifblock'] = function (obj)
- local lastBlock = Block
- Block = obj
- Compile(obj.filter, obj)
- CompileBlock(obj, obj)
- if Block.locals then
- LocalCount = LocalCount - #Block.locals
- end
- Block = lastBlock
- end,
- ['elseifblock'] = function (obj)
- local lastBlock = Block
- Block = obj
- Compile(obj.filter, obj)
- CompileBlock(obj, obj)
- if Block.locals then
- LocalCount = LocalCount - #Block.locals
- end
- Block = lastBlock
- end,
- ['elseblock'] = function (obj)
- local lastBlock = Block
- Block = obj
- CompileBlock(obj, obj)
- if Block.locals then
- LocalCount = LocalCount - #Block.locals
- end
- Block = lastBlock
- end,
- ['loop'] = function (obj)
- local lastBlock = Block
- Block = obj
- Compile(obj.loc, obj)
- Compile(obj.max, obj)
- Compile(obj.step, obj)
- CompileBlock(obj, obj)
- if Block.locals then
- LocalCount = LocalCount - #Block.locals
- end
- Block = lastBlock
- end,
- ['in'] = function (obj)
- local lastBlock = Block
- Block = obj
- local keys = obj.keys
- for i = 1, #keys do
- Compile(keys[i], obj)
- end
- CompileBlock(obj, obj)
- if Block.locals then
- LocalCount = LocalCount - #Block.locals
- end
- Block = lastBlock
- end,
- ['while'] = function (obj)
- local lastBlock = Block
- Block = obj
- Compile(obj.filter, obj)
- CompileBlock(obj, obj)
- if Block.locals then
- LocalCount = LocalCount - #Block.locals
- end
- Block = lastBlock
- end,
- ['repeat'] = function (obj)
- local lastBlock = Block
- Block = obj
- CompileBlock(obj, obj)
- Compile(obj.filter, obj)
- if Block.locals then
- LocalCount = LocalCount - #Block.locals
- end
- Block = lastBlock
- end,
- ['break'] = function (obj)
- local block = guide.getBreakBlock(obj)
- if block then
- if not block.breaks then
- block.breaks = {}
- end
- block.breaks[#block.breaks+1] = obj
- else
+ locals[#locals+1] = obj
+ LocalCount = LocalCount + 1
+ if LocalCount > LocalLimit then
pushError {
- type = 'BREAK_OUTSIDE',
+ type = 'LOCAL_LIMIT',
start = obj.start,
finish = obj.finish,
}
end
- end,
- ['main'] = function (obj)
- Block = obj
- Compile({
- type = 'local',
- start = 0,
- finish = 0,
- effect = 0,
- tag = '_ENV',
- special= '_G',
- [1] = ENVMode,
- }, obj)
- --- _ENV 是上值,不计入局部变量计数
- LocalCount = 0
- CompileBlock(obj, obj)
- Block = nil
- end,
-}
-
-function CompileBlock(obj, parent)
- for i = 1, #obj do
- local act = obj[i]
- act.parent = parent
- local f = vmMap[act.type]
- if f then
- f(act)
- end
end
+ return obj
end
-function Compile(obj, parent)
- if not obj then
- return nil
- end
- if Compiled[obj] then
- return
- end
- Compiled[obj] = true
- obj.parent = parent
- local f = vmMap[obj.type]
- if not f then
- return
- end
- f(obj)
+local function pushChunk(chunk)
+ Chunk[#Chunk+1] = chunk
end
-local function compileGoTo(obj)
- local name = obj[1]
- local label = guide.getLabel(obj, name)
- if not label then
- pushError {
- type = 'NO_VISIBLE_LABEL',
- start = obj.start,
- finish = obj.finish,
- info = {
- label = name,
- }
- }
- return
- end
+local function resolveLable(label, obj)
if not label.ref then
label.ref = {}
end
@@ -538,10 +766,10 @@ local function compileGoTo(obj)
local ref = refs[j]
if ref.finish > label.finish then
pushError {
- type = 'JUMP_LOCAL_SCOPE',
- start = obj.start,
- finish = obj.finish,
- info = {
+ type = 'JUMP_LOCAL_SCOPE',
+ start = obj.start,
+ finish = obj.finish,
+ info = {
loc = loc[1],
},
relative = {
@@ -562,42 +790,3129 @@ local function compileGoTo(obj)
end
end
-local function PostCompile()
- for i = 1, #GoToTag do
- compileGoTo(GoToTag[i])
+local function resolveGoTo(gotos)
+ for i = 1, #gotos do
+ local action = gotos[i]
+ local label = guide.getLabel(action, action[1])
+ if label then
+ resolveLable(label, action)
+ else
+ pushError {
+ type = 'NO_VISIBLE_LABEL',
+ start = action.start,
+ finish = action.finish,
+ info = {
+ label = action[1],
+ }
+ }
+ end
end
end
-return function (lua, mode, version, options)
- do
- local state, err = newparser(lua, mode, version, options)
- return state, err
- end
- local state, err = parse(lua, mode, version, options)
- if not state then
- return nil, err
- end
- --if options and options.delay then
- -- options.delay()
- --end
- local clock = os.clock()
- pushError = state.pushError
- ENVMode = state.ENVMode
- Compiled = {}
- GoToTag = {}
+local function popChunk()
+ local chunk = Chunk[#Chunk]
+ if chunk.gotos then
+ resolveGoTo(chunk.gotos)
+ chunk.gotos = nil
+ end
+ local lastAction = chunk[#chunk]
+ if lastAction then
+ chunk.finish = lastAction.finish
+ end
+ Chunk[#Chunk] = nil
+end
+
+local function parseNil()
+ if Tokens[Index + 1] ~= 'nil' then
+ return nil
+ end
+ local offset = Tokens[Index]
+ Index = Index + 2
+ return {
+ type = 'nil',
+ start = getPosition(offset, 'left'),
+ finish = getPosition(offset + 2, 'right'),
+ }
+end
+
+local function parseBoolean()
+ local word = Tokens[Index+1]
+ if word ~= 'true'
+ and word ~= 'false' then
+ return nil
+ end
+ local start = getPosition(Tokens[Index], 'left')
+ local finish = getPosition(Tokens[Index] + #word - 1, 'right')
+ Index = Index + 2
+ return {
+ type = 'boolean',
+ start = start,
+ finish = finish,
+ [1] = word == 'true' and true or false,
+ }
+end
+
+local function parseStringUnicode()
+ local offset = Tokens[Index] + 1
+ if ssub(Lua, offset, offset) ~= '{' then
+ local pos = getPosition(offset, 'left')
+ missSymbol('{', pos)
+ return nil, offset
+ end
+ local leftPos = getPosition(offset, 'left')
+ local x16 = smatch(Lua, '^%w*', offset + 1)
+ local rightPos = getPosition(offset + #x16, 'right')
+ offset = offset + #x16 + 1
+ if ssub(Lua, offset, offset) == '}' then
+ offset = offset + 1
+ rightPos = rightPos + 1
+ else
+ missSymbol('}', rightPos)
+ end
+ offset = offset + 1
+ if #x16 == 0 then
+ pushError {
+ type = 'UTF8_SMALL',
+ start = leftPos,
+ finish = rightPos,
+ }
+ return '', offset
+ end
+ if State.version ~= 'Lua 5.3'
+ and State.version ~= 'Lua 5.4'
+ and State.version ~= 'LuaJIT'
+ then
+ pushError {
+ type = 'ERR_ESC',
+ start = leftPos - 2,
+ finish = rightPos,
+ version = {'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
+ info = {
+ version = State.version,
+ }
+ }
+ return nil, offset
+ end
+ local byte = tonumber(x16, 16)
+ if not byte then
+ for i = 1, #x16 do
+ if not tonumber(ssub(x16, i, i), 16) then
+ pushError {
+ type = 'MUST_X16',
+ start = leftPos + i,
+ finish = leftPos + i + 1,
+ }
+ end
+ end
+ return nil, offset
+ end
+ if State.version == 'Lua 5.4' then
+ if byte < 0 or byte > 0x7FFFFFFF then
+ pushError {
+ type = 'UTF8_MAX',
+ start = leftPos,
+ finish = rightPos,
+ info = {
+ min = '00000000',
+ max = '7FFFFFFF',
+ }
+ }
+ return nil, offset
+ end
+ else
+ if byte < 0 or byte > 0x10FFFF then
+ pushError {
+ type = 'UTF8_MAX',
+ start = leftPos,
+ finish = rightPos,
+ version = byte <= 0x7FFFFFFF and 'Lua 5.4' or nil,
+ info = {
+ min = '000000',
+ max = '10FFFF',
+ }
+ }
+ end
+ end
+ if byte >= 0 and byte <= 0x10FFFF then
+ return uchar(byte), offset
+ end
+ return '', offset
+end
+
+local stringPool = {}
+local function parseShortString()
+ local mark = Tokens[Index+1]
+ local startOffset = Tokens[Index]
+ local startPos = getPosition(startOffset, 'left')
+ Index = Index + 2
+ local stringIndex = 0
+ local currentOffset = startOffset + 1
+ local escs = {}
+ while true do
+ local token = Tokens[Index + 1]
+ if token == mark then
+ stringIndex = stringIndex + 1
+ stringPool[stringIndex] = ssub(Lua, currentOffset, Tokens[Index] - 1)
+ Index = Index + 2
+ break
+ end
+ if NLMap[token] then
+ stringIndex = stringIndex + 1
+ stringPool[stringIndex] = ssub(Lua, currentOffset, Tokens[Index] - 1)
+ missSymbol(mark)
+ break
+ end
+ if not token then
+ stringIndex = stringIndex + 1
+ stringPool[stringIndex] = ssub(Lua, currentOffset or -1)
+ missSymbol(mark)
+ break
+ end
+ if token == '\\' then
+ stringIndex = stringIndex + 1
+ stringPool[stringIndex] = ssub(Lua, currentOffset, Tokens[Index] - 1)
+ currentOffset = Tokens[Index]
+ Index = Index + 2
+ if not Tokens[Index] then
+ goto CONTINUE
+ end
+ local escLeft = getPosition(currentOffset, 'left')
+ -- has space?
+ if Tokens[Index] - currentOffset > 1 then
+ local right = getPosition(currentOffset + 1, 'right')
+ pushError {
+ type = 'ERR_ESC',
+ start = escLeft,
+ finish = right,
+ }
+ escs[#escs+1] = escLeft
+ escs[#escs+1] = right
+ escs[#escs+1] = 'err'
+ goto CONTINUE
+ end
+ local nextToken = ssub(Tokens[Index + 1], 1, 1)
+ if EscMap[nextToken] then
+ stringIndex = stringIndex + 1
+ stringPool[stringIndex] = EscMap[nextToken]
+ currentOffset = Tokens[Index] + #nextToken
+ Index = Index + 2
+ escs[#escs+1] = escLeft
+ escs[#escs+1] = escLeft + 2
+ escs[#escs+1] = 'normal'
+ goto CONTINUE
+ end
+ if nextToken == mark then
+ stringIndex = stringIndex + 1
+ stringPool[stringIndex] = mark
+ currentOffset = Tokens[Index] + #nextToken
+ Index = Index + 2
+ escs[#escs+1] = escLeft
+ escs[#escs+1] = escLeft + 2
+ escs[#escs+1] = 'normal'
+ goto CONTINUE
+ end
+ if nextToken == 'z' then
+ Index = Index + 2
+ repeat until not skipNL()
+ currentOffset = Tokens[Index]
+ escs[#escs+1] = escLeft
+ escs[#escs+1] = escLeft + 2
+ escs[#escs+1] = 'normal'
+ goto CONTINUE
+ end
+ if CharMapNumber[nextToken] then
+ local numbers = smatch(Tokens[Index + 1], '^%d+')
+ if #numbers > 3 then
+ numbers = ssub(numbers, 1, 3)
+ end
+ currentOffset = Tokens[Index] + #numbers
+ fastForwardToken(currentOffset)
+ local right = getPosition(currentOffset - 1, 'right')
+ local byte = tointeger(numbers)
+ if byte and byte <= 255 then
+ stringIndex = stringIndex + 1
+ stringPool[stringIndex] = schar(byte)
+ else
+ pushError {
+ type = 'ERR_ESC',
+ start = escLeft,
+ finish = right,
+ }
+ end
+ escs[#escs+1] = escLeft
+ escs[#escs+1] = right
+ escs[#escs+1] = 'byte'
+ goto CONTINUE
+ end
+ if nextToken == 'x' then
+ local left = getPosition(Tokens[Index] - 1, 'left')
+ local x16 = ssub(Tokens[Index + 1], 2, 3)
+ local byte = tonumber(x16, 16)
+ if byte then
+ currentOffset = Tokens[Index] + 3
+ stringIndex = stringIndex + 1
+ stringPool[stringIndex] = schar(byte)
+ else
+ currentOffset = Tokens[Index] + 1
+ pushError {
+ type = 'MISS_ESC_X',
+ start = getPosition(currentOffset, 'left'),
+ finish = getPosition(currentOffset + 1, 'right'),
+ }
+ end
+ local right = getPosition(currentOffset + 1, 'right')
+ escs[#escs+1] = escLeft
+ escs[#escs+1] = right
+ escs[#escs+1] = 'byte'
+ if State.version == 'Lua 5.1' then
+ pushError {
+ type = 'ERR_ESC',
+ start = left,
+ finish = left + 4,
+ version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
+ info = {
+ version = State.version,
+ }
+ }
+ end
+ Index = Index + 2
+ goto CONTINUE
+ end
+ if nextToken == 'u' then
+ local str, newOffset = parseStringUnicode()
+ if str then
+ stringIndex = stringIndex + 1
+ stringPool[stringIndex] = str
+ end
+ currentOffset = newOffset
+ fastForwardToken(currentOffset - 1)
+ local right = getPosition(currentOffset + 1, 'right')
+ escs[#escs+1] = escLeft
+ escs[#escs+1] = right
+ escs[#escs+1] = 'unicode'
+ goto CONTINUE
+ end
+ if NLMap[nextToken] then
+ stringIndex = stringIndex + 1
+ stringPool[stringIndex] = '\n'
+ currentOffset = Tokens[Index] + #nextToken
+ skipNL()
+ local right = getPosition(currentOffset + 1, 'right')
+ escs[#escs+1] = escLeft
+ escs[#escs+1] = escLeft + 1
+ escs[#escs+1] = 'normal'
+ goto CONTINUE
+ end
+ local right = getPosition(currentOffset + 1, 'right')
+ pushError {
+ type = 'ERR_ESC',
+ start = escLeft,
+ finish = right,
+ }
+ escs[#escs+1] = escLeft
+ escs[#escs+1] = right
+ escs[#escs+1] = 'err'
+ end
+ Index = Index + 2
+ ::CONTINUE::
+ end
+ local stringResult = tconcat(stringPool, '', 1, stringIndex)
+ local str = {
+ type = 'string',
+ start = startPos,
+ finish = lastRightPosition(),
+ escs = #escs > 0 and escs or nil,
+ [1] = stringResult,
+ [2] = mark,
+ }
+ if mark == '`' then
+ if not State.options.nonstandardSymbol[mark] then
+ pushError {
+ type = 'ERR_NONSTANDARD_SYMBOL',
+ start = startPos,
+ finish = str.finish,
+ info = {
+ symbol = '"',
+ },
+ fix = {
+ title = 'FIX_NONSTANDARD_SYMBOL',
+ symbol = '"',
+ {
+ start = startPos,
+ finish = startPos + 1,
+ text = '"',
+ },
+ {
+ start = str.finish - 1,
+ finish = str.finish,
+ text = '"',
+ },
+ }
+ }
+ end
+ end
+ return str
+end
+
+local function parseString()
+ local c = Tokens[Index + 1]
+ if CharMapStrSH[c] then
+ return parseShortString()
+ end
+ if CharMapStrLH[c] then
+ return parseLongString()
+ end
+ return nil
+end
+
+local function parseNumber10(start)
+ local integer = true
+ local integerPart = smatch(Lua, '^%d*', start)
+ local offset = start + #integerPart
+ -- float part
+ if ssub(Lua, offset, offset) == '.' then
+ local floatPart = smatch(Lua, '^%d*', offset + 1)
+ integer = false
+ offset = offset + #floatPart + 1
+ end
+ -- exp part
+ local echar = ssub(Lua, offset, offset)
+ if CharMapE10[echar] then
+ integer = false
+ offset = offset + 1
+ local nextChar = ssub(Lua, offset, offset)
+ if CharMapSign[nextChar] then
+ offset = offset + 1
+ end
+ local exp = smatch(Lua, '^%d*', offset)
+ offset = offset + #exp
+ if #exp == 0 then
+ pushError {
+ type = 'MISS_EXPONENT',
+ start = getPosition(offset - 1, 'right'),
+ finish = getPosition(offset - 1, 'right'),
+ }
+ end
+ end
+ return tonumber(ssub(Lua, start, offset - 1)), offset, integer
+end
+
+local function parseNumber16(start)
+ local integerPart = smatch(Lua, '^[%da-fA-F]*', start)
+ local offset = start + #integerPart
+ local integer = true
+ -- float part
+ if ssub(Lua, offset, offset) == '.' then
+ local floatPart = smatch(Lua, '^[%da-fA-F]*', offset + 1)
+ integer = false
+ offset = offset + #floatPart + 1
+ if #integerPart == 0 and #floatPart == 0 then
+ pushError {
+ type = 'MUST_X16',
+ start = getPosition(offset - 1, 'right'),
+ finish = getPosition(offset - 1, 'right'),
+ }
+ end
+ else
+ if #integerPart == 0 then
+ pushError {
+ type = 'MUST_X16',
+ start = getPosition(offset - 1, 'right'),
+ finish = getPosition(offset - 1, 'right'),
+ }
+ return 0, offset
+ end
+ end
+ -- exp part
+ local echar = ssub(Lua, offset, offset)
+ if CharMapE16[echar] then
+ integer = false
+ offset = offset + 1
+ local nextChar = ssub(Lua, offset, offset)
+ if CharMapSign[nextChar] then
+ offset = offset + 1
+ end
+ local exp = smatch(Lua, '^%d*', offset)
+ offset = offset + #exp
+ end
+ local n = tonumber(ssub(Lua, start - 2, offset - 1))
+ return n, offset, integer
+end
+
+local function parseNumber2(start)
+ local bins = smatch(Lua, '^[01]*', start)
+ local offset = start + #bins
+ if State.version ~= 'LuaJIT' then
+ pushError {
+ type = 'UNSUPPORT_SYMBOL',
+ start = getPosition(start - 2, 'left'),
+ finish = getPosition(offset - 1, 'right'),
+ version = 'LuaJIT',
+ info = {
+ version = 'Lua 5.4',
+ }
+ }
+ end
+ return tonumber(bins, 2), offset
+end
+
+local function dropNumberTail(offset, integer)
+ local _, finish, word = sfind(Lua, '^([%.%w_\x80-\xff]+)', offset)
+ if not finish then
+ return offset
+ end
+ if integer then
+ if supper(ssub(word, 1, 2)) == 'LL' then
+ if State.version ~= 'LuaJIT' then
+ pushError {
+ type = 'UNSUPPORT_SYMBOL',
+ start = getPosition(offset, 'left'),
+ finish = getPosition(offset + 1, 'right'),
+ version = 'LuaJIT',
+ info = {
+ version = State.version,
+ }
+ }
+ end
+ offset = offset + 2
+ word = ssub(word, offset)
+ elseif supper(ssub(word, 1, 3)) == 'ULL' then
+ if State.version ~= 'LuaJIT' then
+ pushError {
+ type = 'UNSUPPORT_SYMBOL',
+ start = getPosition(offset, 'left'),
+ finish = getPosition(offset + 2, 'right'),
+ version = 'LuaJIT',
+ info = {
+ version = State.version,
+ }
+ }
+ end
+ offset = offset + 3
+ word = ssub(word, offset)
+ end
+ end
+ if supper(ssub(word, 1, 1)) == 'I' then
+ if State.version ~= 'LuaJIT' then
+ pushError {
+ type = 'UNSUPPORT_SYMBOL',
+ start = getPosition(offset, 'left'),
+ finish = getPosition(offset, 'right'),
+ version = 'LuaJIT',
+ info = {
+ version = State.version,
+ }
+ }
+ end
+ offset = offset + 1
+ word = ssub(word, offset)
+ end
+ if #word > 0 then
+ pushError {
+ type = 'MALFORMED_NUMBER',
+ start = getPosition(offset, 'left'),
+ finish = getPosition(finish, 'right'),
+ }
+ end
+ return finish + 1
+end
+
+local function parseNumber()
+ local offset = Tokens[Index]
+ if not offset then
+ return nil
+ end
+ local startPos = getPosition(offset, 'left')
+ local neg
+ if ssub(Lua, offset, offset) == '-' then
+ neg = true
+ offset = offset + 1
+ end
+ local number, integer
+ local firstChar = ssub(Lua, offset, offset)
+ if firstChar == '.' then
+ number, offset = parseNumber10(offset)
+ integer = false
+ elseif firstChar == '0' then
+ local nextChar = ssub(Lua, offset + 1, offset + 1)
+ if CharMapN16[nextChar] then
+ number, offset, integer = parseNumber16(offset + 2)
+ elseif CharMapN2[nextChar] then
+ number, offset = parseNumber2(offset + 2)
+ integer = true
+ else
+ number, offset, integer = parseNumber10(offset)
+ end
+ elseif CharMapNumber[firstChar] then
+ number, offset, integer = parseNumber10(offset)
+ else
+ return nil
+ end
+ if not number then
+ number = 0
+ end
+ if neg then
+ number = - number
+ end
+ local result = {
+ type = integer and 'integer' or 'number',
+ start = startPos,
+ finish = getPosition(offset - 1, 'right'),
+ [1] = number,
+ }
+ offset = dropNumberTail(offset, integer)
+ fastForwardToken(offset)
+ return result
+end
+
+local function isKeyWord(word)
+ if KeyWord[word] then
+ return true
+ end
+ if word == 'goto' then
+ return State.version ~= 'Lua 5.1'
+ end
+ return false
+end
+
+local function parseName(asAction)
+ local word = peekWord()
+ if not word then
+ return nil
+ end
+ if ChunkFinishMap[word] then
+ return nil
+ end
+ if asAction and ChunkStartMap[word] then
+ return nil
+ end
+ local startPos = getPosition(Tokens[Index], 'left')
+ local finishPos = getPosition(Tokens[Index] + #word - 1, 'right')
+ Index = Index + 2
+ if not State.options.unicodeName and word:find '[\x80-\xff]' then
+ pushError {
+ type = 'UNICODE_NAME',
+ start = startPos,
+ finish = finishPos,
+ }
+ end
+ if isKeyWord(word) then
+ pushError {
+ type = 'KEYWORD',
+ start = startPos,
+ finish = finishPos,
+ }
+ end
+ return {
+ type = 'name',
+ start = startPos,
+ finish = finishPos,
+ [1] = word,
+ }
+end
+
+local function parseNameOrList(parent)
+ local first = parseName()
+ if not first then
+ return nil
+ end
+ skipSpace()
+ local list
+ while true do
+ if Tokens[Index + 1] ~= ',' then
+ break
+ end
+ Index = Index + 2
+ skipSpace()
+ local name = parseName(true)
+ if not name then
+ missName()
+ break
+ end
+ if not list then
+ list = {
+ type = 'list',
+ start = first.start,
+ finish = first.finish,
+ parent = parent,
+ [1] = first
+ }
+ end
+ list[#list+1] = name
+ list.finish = name.finish
+ end
+ return list or first
+end
+
+local function dropTail()
+ local token = Tokens[Index + 1]
+ if token ~= '?'
+ and token ~= ':' then
+ return
+ end
+ local pl, pt, pp = 0, 0, 0
+ while true do
+ local token = Tokens[Index + 1]
+ if not token then
+ break
+ end
+ if NLMap[token] then
+ break
+ end
+ if token == ',' then
+ if pl > 0
+ or pt > 0
+ or pp > 0 then
+ goto CONTINUE
+ else
+ break
+ end
+ end
+ if token == '<' then
+ pl = pl + 1
+ goto CONTINUE
+ end
+ if token == '{' then
+ pt = pt + 1
+ goto CONTINUE
+ end
+ if token == '(' then
+ pp = pp + 1
+ goto CONTINUE
+ end
+ if token == '>' then
+ if pl <= 0 then
+ break
+ end
+ pl = pl - 1
+ goto CONTINUE
+ end
+ if token == '}' then
+ if pt <= 0 then
+ break
+ end
+ pt = pt - 1
+ goto CONTINUE
+ end
+ if token == ')' then
+ if pp <= 0 then
+ break
+ end
+ pp = pp - 1
+ goto CONTINUE
+ end
+ ::CONTINUE::
+ Index = Index + 2
+ end
+end
+
+local function parseExpList(mini)
+ local list
+ local wantSep = false
+ while true do
+ skipSpace()
+ local token = Tokens[Index + 1]
+ if not token then
+ break
+ end
+ if ListFinishMap[token] then
+ break
+ end
+ if token == ',' then
+ local sepPos = getPosition(Tokens[Index], 'right')
+ if not wantSep then
+ pushError {
+ type = 'UNEXPECT_SYMBOL',
+ start = getPosition(Tokens[Index], 'left'),
+ finish = sepPos,
+ info = {
+ symbol = ',',
+ }
+ }
+ end
+ wantSep = false
+ Index = Index + 2
+ goto CONTINUE
+ else
+ if mini then
+ if wantSep then
+ break
+ end
+ local nextToken = peekWord()
+ if isKeyWord(nextToken)
+ and nextToken ~= 'function'
+ and nextToken ~= 'true'
+ and nextToken ~= 'false'
+ and nextToken ~= 'nil'
+ and nextToken ~= 'not' then
+ break
+ end
+ end
+ local exp = parseExp()
+ if not exp then
+ break
+ end
+ dropTail()
+ if wantSep then
+ missSymbol(',', list[#list].finish, exp.start)
+ end
+ wantSep = true
+ if not list then
+ list = {
+ type = 'list',
+ start = exp.start,
+ }
+ end
+ list[#list+1] = exp
+ list.finish = exp.finish
+ exp.parent = list
+ end
+ ::CONTINUE::
+ end
+ if not list then
+ return nil
+ end
+ if not wantSep then
+ missExp()
+ end
+ return list
+end
+
+local function parseIndex()
+ local start = getPosition(Tokens[Index], 'left')
+ Index = Index + 2
+ skipSpace()
+ local exp = parseExp()
+ local index = {
+ type = 'index',
+ start = start,
+ finish = exp and exp.finish or (start + 1),
+ index = exp
+ }
+ if exp then
+ exp.parent = index
+ else
+ missExp()
+ end
+ skipSpace()
+ if Tokens[Index + 1] == ']' then
+ index.finish = getPosition(Tokens[Index], 'right')
+ Index = Index + 2
+ else
+ missSymbol ']'
+ end
+ return index
+end
+
+local function parseTable()
+ local tbl = {
+ type = 'table',
+ start = getPosition(Tokens[Index], 'left'),
+ finish = getPosition(Tokens[Index], 'right'),
+ }
+ Index = Index + 2
+ local index = 0
+ local tindex = 0
+ local wantSep = false
+ while true do
+ skipSpace(true)
+ local token = Tokens[Index + 1]
+ if token == '}' then
+ Index = Index + 2
+ break
+ end
+ if CharMapTSep[token] then
+ if not wantSep then
+ missExp()
+ end
+ wantSep = false
+ Index = Index + 2
+ goto CONTINUE
+ end
+ local lastRight = lastRightPosition()
+
+ if peekWord() then
+ local savePoint = getSavePoint()
+ local name = parseName()
+ if name then
+ skipSpace()
+ if Tokens[Index + 1] == '=' then
+ Index = Index + 2
+ if wantSep then
+ pushError {
+ type = 'MISS_SEP_IN_TABLE',
+ start = lastRight,
+ finish = getPosition(Tokens[Index], 'left'),
+ }
+ end
+ wantSep = true
+ skipSpace()
+ local fvalue = parseExp()
+ local tfield = {
+ type = 'tablefield',
+ start = name.start,
+ finish = name.finish,
+ range = fvalue and fvalue.finish,
+ node = tbl,
+ parent = tbl,
+ field = name,
+ value = fvalue,
+ }
+ name.type = 'field'
+ name.parent = tfield
+ if fvalue then
+ fvalue.parent = tfield
+ else
+ missExp()
+ end
+ index = index + 1
+ tbl[index] = tfield
+ goto CONTINUE
+ end
+ end
+ savePoint()
+ end
+
+ local exp = parseExp(true)
+ if exp then
+ if wantSep then
+ pushError {
+ type = 'MISS_SEP_IN_TABLE',
+ start = lastRight,
+ finish = exp.start,
+ }
+ end
+ wantSep = true
+ if exp.type == 'varargs' then
+ index = index + 1
+ tbl[index] = exp
+ exp.parent = tbl
+ goto CONTINUE
+ end
+ index = index + 1
+ tindex = tindex + 1
+ local texp = {
+ type = 'tableexp',
+ start = exp.start,
+ finish = exp.finish,
+ tindex = tindex,
+ parent = tbl,
+ value = exp,
+ }
+ exp.parent = texp
+ tbl[index] = texp
+ goto CONTINUE
+ end
+
+ if token == '[' then
+ if wantSep then
+ pushError {
+ type = 'MISS_SEP_IN_TABLE',
+ start = lastRight,
+ finish = getPosition(Tokens[Index], 'left'),
+ }
+ end
+ wantSep = true
+ local tindex = parseIndex()
+ skipSpace()
+ tindex.type = 'tableindex'
+ tindex.node = tbl
+ tindex.parent = tbl
+ index = index + 1
+ tbl[index] = tindex
+ if expectAssign() then
+ skipSpace()
+ local ivalue = parseExp()
+ if ivalue then
+ ivalue.parent = tindex
+ tindex.range = ivalue.finish
+ tindex.value = ivalue
+ else
+ missExp()
+ end
+ else
+ missSymbol '='
+ end
+ goto CONTINUE
+ end
+
+ missSymbol '}'
+ break
+ ::CONTINUE::
+ end
+ tbl.finish = lastRightPosition()
+ return tbl
+end
+
+local function addDummySelf(node, call)
+ if node.type ~= 'getmethod' then
+ return
+ end
+ -- dummy param `self`
+ if not call.args then
+ call.args = {
+ type = 'callargs',
+ start = call.start,
+ finish = call.finish,
+ parent = call,
+ }
+ end
+ local self = {
+ type = 'self',
+ start = node.colon.start,
+ finish = node.colon.finish,
+ parent = call.args,
+ [1] = 'self',
+ }
+ tinsert(call.args, 1, self)
+end
+
+local function checkAmbiguityCall(call, parenPos)
+ if State.version ~= 'Lua 5.1' then
+ return
+ end
+ local node = call.node
+ local nodeRow = guide.rowColOf(node.finish)
+ local callRow = guide.rowColOf(parenPos)
+ if nodeRow == callRow then
+ return
+ end
+ pushError {
+ type = 'AMBIGUOUS_SYNTAX',
+ start = parenPos,
+ finish = call.finish,
+ }
+end
+
+local function parseSimple(node, funcName)
+ local lastMethod
+ while true do
+ if lastMethod and node.node == lastMethod then
+ if node.type ~= 'call' then
+ missSymbol('(', node.node.finish, node.node.finish)
+ end
+ lastMethod = nil
+ end
+ skipSpace()
+ local token = Tokens[Index + 1]
+ if token == '.' then
+ local dot = {
+ type = token,
+ start = getPosition(Tokens[Index], 'left'),
+ finish = getPosition(Tokens[Index], 'right'),
+ }
+ Index = Index + 2
+ skipSpace()
+ local field = parseName(true)
+ local getfield = {
+ type = 'getfield',
+ start = node.start,
+ finish = lastRightPosition(),
+ node = node,
+ dot = dot,
+ field = field
+ }
+ if field then
+ field.parent = getfield
+ field.type = 'field'
+ else
+ pushError {
+ type = 'MISS_FIELD',
+ start = lastRightPosition(),
+ finish = lastRightPosition(),
+ }
+ end
+ node.parent = getfield
+ node.next = getfield
+ node = getfield
+ elseif token == ':' then
+ local colon = {
+ type = token,
+ start = getPosition(Tokens[Index], 'left'),
+ finish = getPosition(Tokens[Index], 'right'),
+ }
+ Index = Index + 2
+ skipSpace()
+ local method = parseName(true)
+ local getmethod = {
+ type = 'getmethod',
+ start = node.start,
+ finish = lastRightPosition(),
+ node = node,
+ colon = colon,
+ method = method
+ }
+ if method then
+ method.parent = getmethod
+ method.type = 'method'
+ else
+ pushError {
+ type = 'MISS_METHOD',
+ start = lastRightPosition(),
+ finish = lastRightPosition(),
+ }
+ end
+ node.parent = getmethod
+ node.next = getmethod
+ node = getmethod
+ if lastMethod then
+ missSymbol('(', node.node.finish, node.node.finish)
+ end
+ lastMethod = getmethod
+ elseif token == '(' then
+ if funcName then
+ break
+ end
+ local startPos = getPosition(Tokens[Index], 'left')
+ local call = {
+ type = 'call',
+ start = node.start,
+ node = node,
+ }
+ Index = Index + 2
+ local args = parseExpList()
+ if Tokens[Index + 1] == ')' then
+ call.finish = getPosition(Tokens[Index], 'right')
+ Index = Index + 2
+ else
+ call.finish = lastRightPosition()
+ missSymbol ')'
+ end
+ if args then
+ args.type = 'callargs'
+ args.start = startPos
+ args.finish = call.finish
+ args.parent = call
+ call.args = args
+ end
+ addDummySelf(node, call)
+ checkAmbiguityCall(call, startPos)
+ node.parent = call
+ node = call
+ elseif token == '{' then
+ if funcName then
+ break
+ end
+ local tbl = parseTable()
+ local call = {
+ type = 'call',
+ start = node.start,
+ finish = tbl.finish,
+ node = node,
+ }
+ local args = {
+ type = 'callargs',
+ start = tbl.start,
+ finish = tbl.finish,
+ parent = call,
+ [1] = tbl,
+ }
+ call.args = args
+ addDummySelf(node, call)
+ tbl.parent = args
+ node.parent = call
+ node = call
+ elseif CharMapStrSH[token] then
+ if funcName then
+ break
+ end
+ local str = parseShortString()
+ local call = {
+ type = 'call',
+ start = node.start,
+ finish = str.finish,
+ node = node,
+ }
+ local args = {
+ type = 'callargs',
+ start = str.start,
+ finish = str.finish,
+ parent = call,
+ [1] = str,
+ }
+ call.args = args
+ addDummySelf(node, call)
+ str.parent = args
+ node.parent = call
+ node = call
+ elseif CharMapStrLH[token] then
+ local str = parseLongString()
+ if str then
+ if funcName then
+ break
+ end
+ local call = {
+ type = 'call',
+ start = node.start,
+ finish = str.finish,
+ node = node,
+ }
+ local args = {
+ type = 'callargs',
+ start = str.start,
+ finish = str.finish,
+ parent = call,
+ [1] = str,
+ }
+ call.args = args
+ addDummySelf(node, call)
+ str.parent = args
+ node.parent = call
+ node = call
+ else
+ local index = parseIndex()
+ local bstart = index.start
+ index.type = 'getindex'
+ index.start = node.start
+ index.node = node
+ node.next = index
+ node.parent = index
+ node = index
+ if funcName then
+ pushError {
+ type = 'INDEX_IN_FUNC_NAME',
+ start = bstart,
+ finish = index.finish,
+ }
+ end
+ end
+ else
+ break
+ end
+ end
+ if node.type == 'call'
+ and node.node == lastMethod then
+ lastMethod = nil
+ end
+ if node == lastMethod then
+ if funcName then
+ lastMethod = nil
+ end
+ end
+ if lastMethod then
+ missSymbol('(', lastMethod.finish)
+ end
+ return node
+end
+
+local function parseVarargs()
+ local varargs = {
+ type = 'varargs',
+ start = getPosition(Tokens[Index], 'left'),
+ finish = getPosition(Tokens[Index] + 2, 'right'),
+ }
+ Index = Index + 2
+ for i = #Chunk, 1, -1 do
+ local chunk = Chunk[i]
+ if chunk.vararg then
+ if not chunk.vararg.ref then
+ chunk.vararg.ref = {}
+ end
+ chunk.vararg.ref[#chunk.vararg.ref+1] = varargs
+ varargs.node = chunk.vararg
+ break
+ end
+ if chunk.type == 'main' then
+ break
+ end
+ if chunk.type == 'function' then
+ pushError {
+ type = 'UNEXPECT_DOTS',
+ start = varargs.start,
+ finish = varargs.finish,
+ }
+ break
+ end
+ end
+ return varargs
+end
+
+local function parseParen()
+ local pl = Tokens[Index]
+ local paren = {
+ type = 'paren',
+ start = getPosition(pl, 'left'),
+ finish = getPosition(pl, 'right')
+ }
+ Index = Index + 2
+ skipSpace()
+ local exp = parseExp()
+ if exp then
+ paren.exp = exp
+ paren.finish = exp.finish
+ exp.parent = paren
+ else
+ missExp()
+ end
+ skipSpace()
+ if Tokens[Index + 1] == ')' then
+ paren.finish = getPosition(Tokens[Index], 'right')
+ Index = Index + 2
+ else
+ missSymbol ')'
+ end
+ return paren
+end
+
+local function getLocal(name, pos)
+ for i = #Chunk, 1, -1 do
+ local chunk = Chunk[i]
+ local locals = chunk.locals
+ if locals then
+ local res
+ for n = 1, #locals do
+ local loc = locals[n]
+ if loc.effect > pos then
+ break
+ end
+ if loc[1] == name then
+ if not res or res.effect < loc.effect then
+ res = loc
+ end
+ end
+ end
+ if res then
+ return res
+ end
+ end
+ end
+end
+
+local function resolveName(node)
+ if not node then
+ return nil
+ end
+ local loc = getLocal(node[1], node.start)
+ if loc then
+ node.type = 'getlocal'
+ node.node = loc
+ if not loc.ref then
+ loc.ref = {}
+ end
+ loc.ref[#loc.ref+1] = node
+ if loc.special then
+ addSpecial(loc.special, node)
+ end
+ else
+ node.type = 'getglobal'
+ local env = getLocal(State.ENVMode, node.start)
+ if env then
+ node.node = env
+ if not env.ref then
+ env.ref = {}
+ end
+ env.ref[#env.ref+1] = node
+ end
+ end
+ local name = node[1]
+ if Specials[name] then
+ addSpecial(name, node)
+ else
+ local ospeicals = State.options.special
+ if ospeicals and ospeicals[name] then
+ addSpecial(ospeicals[name], node)
+ end
+ end
+ return node
+end
+
+local function isChunkFinishToken(token)
+ local currentChunk = Chunk[#Chunk]
+ if not currentChunk then
+ return false
+ end
+ local tp = currentChunk.type
+ if tp == 'main' then
+ return false
+ end
+ if tp == 'for'
+ or tp == 'in'
+ or tp == 'loop'
+ or tp == 'function' then
+ return token == 'end'
+ end
+ if tp == 'if'
+ or tp == 'ifblock'
+ or tp == 'elseifblock'
+ or tp == 'elseblock' then
+ return token == 'then'
+ or token == 'end'
+ or token == 'else'
+ or token == 'elseif'
+ end
+ if tp == 'repeat' then
+ return token == 'until'
+ end
+ return true
+end
+
+local function parseActions()
+ local rtn, last
+ while true do
+ skipSpace(true)
+ local token = Tokens[Index + 1]
+ if token == ';' then
+ Index = Index + 2
+ goto CONTINUE
+ end
+ if ChunkFinishMap[token]
+ and isChunkFinishToken(token) then
+ break
+ end
+ local action, failed = parseAction()
+ if failed then
+ if not skipUnknownSymbol() then
+ break
+ end
+ end
+ if action then
+ if not rtn and action.type == 'return' then
+ rtn = action
+ end
+ last = action
+ end
+ ::CONTINUE::
+ end
+ if rtn and rtn ~= last then
+ pushError {
+ type = 'ACTION_AFTER_RETURN',
+ start = rtn.start,
+ finish = rtn.finish,
+ }
+ end
+end
+
+local function parseParams(params)
+ local lastSep
+ local hasDots
+ while true do
+ skipSpace()
+ local token = Tokens[Index + 1]
+ if not token or token == ')' then
+ if lastSep then
+ missName()
+ end
+ break
+ end
+ if token == ',' then
+ if lastSep or lastSep == nil then
+ missName()
+ else
+ lastSep = true
+ end
+ Index = Index + 2
+ goto CONTINUE
+ end
+ if token == '...' then
+ if lastSep == false then
+ missSymbol ','
+ end
+ lastSep = false
+ if not params then
+ params = {}
+ end
+ local vararg = {
+ type = '...',
+ start = getPosition(Tokens[Index], 'left'),
+ finish = getPosition(Tokens[Index] + 2, 'right'),
+ parent = params,
+ [1] = '...',
+ }
+ local chunk = Chunk[#Chunk]
+ chunk.vararg = vararg
+ params[#params+1] = vararg
+ if hasDots then
+ pushError {
+ type = 'ARGS_AFTER_DOTS',
+ start = getPosition(Tokens[Index], 'left'),
+ finish = getPosition(Tokens[Index] + 2, 'right'),
+ }
+ end
+ hasDots = true
+ Index = Index + 2
+ goto CONTINUE
+ end
+ if CharMapWord[ssub(token, 1, 1)] then
+ if lastSep == false then
+ missSymbol ','
+ end
+ lastSep = false
+ if not params then
+ params = {}
+ end
+ params[#params+1] = createLocal {
+ start = getPosition(Tokens[Index], 'left'),
+ finish = getPosition(Tokens[Index] + #token - 1, 'right'),
+ parent = params,
+ [1] = token,
+ }
+ if hasDots then
+ pushError {
+ type = 'ARGS_AFTER_DOTS',
+ start = getPosition(Tokens[Index], 'left'),
+ finish = getPosition(Tokens[Index] + #token - 1, 'right'),
+ }
+ end
+ if isKeyWord(token) then
+ pushError {
+ type = 'KEYWORD',
+ start = getPosition(Tokens[Index], 'left'),
+ finish = getPosition(Tokens[Index] + #token - 1, 'right'),
+ }
+ end
+ Index = Index + 2
+ goto CONTINUE
+ end
+ skipUnknownSymbol '%,%)%.'
+ ::CONTINUE::
+ end
+ return params
+end
+
+local function parseFunction(isLocal, isAction)
+ local funcLeft = getPosition(Tokens[Index], 'left')
+ local funcRight = getPosition(Tokens[Index] + 7, 'right')
+ local func = {
+ type = 'function',
+ start = funcLeft,
+ finish = funcRight,
+ keyword = {
+ [1] = funcLeft,
+ [2] = funcRight,
+ },
+ }
+ Index = Index + 2
+ local LastLocalCount = LocalCount
LocalCount = 0
- Version = version
- Root = state.ast
- if Root then
- Root.state = state
- end
- Options = options
- if type(state.ast) == 'table' then
- Compile(state.ast)
- end
- PostCompile()
- state.compileClock = os.clock() - clock
- Compiled = nil
- GoToTag = nil
- return state
+ skipSpace(true)
+ local hasLeftParen = Tokens[Index + 1] == '('
+ if not hasLeftParen then
+ local name = parseName()
+ if name then
+ local simple = parseSimple(name, true)
+ if isLocal then
+ if simple == name then
+ createLocal(name)
+ else
+ resolveName(name)
+ pushError {
+ type = 'UNEXPECT_LFUNC_NAME',
+ start = simple.start,
+ finish = simple.finish,
+ }
+ end
+ else
+ resolveName(name)
+ end
+ func.name = simple
+ func.finish = simple.finish
+ if not isAction then
+ simple.parent = func
+ pushError {
+ type = 'UNEXPECT_EFUNC_NAME',
+ start = simple.start,
+ finish = simple.finish,
+ }
+ end
+ skipSpace(true)
+ hasLeftParen = Tokens[Index + 1] == '('
+ end
+ end
+ pushChunk(func)
+ local params
+ if func.name and func.name.type == 'getmethod' then
+ if func.name.type == 'getmethod' then
+ params = {}
+ params[1] = createLocal {
+ start = funcRight,
+ finish = funcRight,
+ parent = params,
+ [1] = 'self',
+ }
+ params[1].type = 'self'
+ end
+ end
+ if hasLeftParen then
+ local parenLeft = getPosition(Tokens[Index], 'left')
+ Index = Index + 2
+ params = parseParams(params)
+ if params then
+ params.type = 'funcargs'
+ params.start = parenLeft
+ params.finish = lastRightPosition()
+ params.parent = func
+ func.args = params
+ end
+ skipSpace(true)
+ if Tokens[Index + 1] == ')' then
+ local parenRight = getPosition(Tokens[Index], 'right')
+ func.finish = parenRight
+ if params then
+ params.finish = parenRight
+ end
+ Index = Index + 2
+ skipSpace(true)
+ else
+ func.finish = lastRightPosition()
+ if params then
+ params.finish = func.finish
+ end
+ missSymbol ')'
+ end
+ else
+ missSymbol '('
+ end
+ parseActions()
+ popChunk()
+ if Tokens[Index + 1] == 'end' then
+ local endLeft = getPosition(Tokens[Index], 'left')
+ local endRight = getPosition(Tokens[Index] + 2, 'right')
+ func.keyword[3] = endLeft
+ func.keyword[4] = endRight
+ func.finish = endRight
+ Index = Index + 2
+ else
+ func.finish = lastRightPosition()
+ missEnd(funcLeft, funcRight)
+ end
+ LocalCount = LastLocalCount
+ return func
+end
+
+local function parseExpUnit()
+ local token = Tokens[Index + 1]
+ if token == '(' then
+ local paren = parseParen()
+ return parseSimple(paren, false)
+ end
+
+ if token == '...' then
+ local varargs = parseVarargs()
+ return varargs
+ end
+
+ if token == '{' then
+ local table = parseTable()
+ return table
+ end
+
+ if CharMapStrSH[token] then
+ local string = parseShortString()
+ return string
+ end
+
+ if CharMapStrLH[token] then
+ local string = parseLongString()
+ return string
+ end
+
+ local number = parseNumber()
+ if number then
+ return number
+ end
+
+ if ChunkFinishMap[token] then
+ return nil
+ end
+
+ if token == 'nil' then
+ return parseNil()
+ end
+
+ if token == 'true'
+ or token == 'false' then
+ return parseBoolean()
+ end
+
+ if token == 'function' then
+ return parseFunction()
+ end
+
+ local node = parseName()
+ if node then
+ return parseSimple(resolveName(node), false)
+ end
+
+ return nil
+end
+
+local function parseUnaryOP()
+ local token = Tokens[Index + 1]
+ local symbol = UnarySymbol[token] and token or UnaryAlias[token]
+ if not symbol then
+ return nil
+ end
+ local myLevel = UnarySymbol[symbol]
+ local op = {
+ type = symbol,
+ start = getPosition(Tokens[Index], 'left'),
+ finish = getPosition(Tokens[Index] + #symbol - 1, 'right'),
+ }
+ Index = Index + 2
+ return op, myLevel
+end
+
+---@param level integer # op level must greater than this level
+local function parseBinaryOP(asAction, level)
+ local token = Tokens[Index + 1]
+ local symbol = (BinarySymbol[token] and token)
+ or BinaryAlias[token]
+ or (not asAction and BinaryActionAlias[token])
+ if not symbol then
+ return nil
+ end
+ if symbol == '//' and State.options.nonstandardSymbol['//'] then
+ return nil
+ end
+ local myLevel = BinarySymbol[symbol]
+ if level and myLevel < level then
+ return nil
+ end
+ local op = {
+ type = symbol,
+ start = getPosition(Tokens[Index], 'left'),
+ finish = getPosition(Tokens[Index] + #token - 1, 'right'),
+ }
+ if not asAction then
+ if token == '=' then
+ pushError {
+ type = 'ERR_EQ_AS_ASSIGN',
+ start = op.start,
+ finish = op.finish,
+ fix = {
+ title = 'FIX_EQ_AS_ASSIGN',
+ {
+ start = op.start,
+ finish = op.finish,
+ text = '==',
+ }
+ }
+ }
+ end
+ end
+ if BinaryAlias[token] then
+ if not State.options.nonstandardSymbol[token] then
+ pushError {
+ type = 'ERR_NONSTANDARD_SYMBOL',
+ start = op.start,
+ finish = op.finish,
+ info = {
+ symbol = symbol,
+ },
+ fix = {
+ title = 'FIX_NONSTANDARD_SYMBOL',
+ symbol = symbol,
+ {
+ start = op.start,
+ finish = op.finish,
+ text = symbol,
+ },
+ }
+ }
+ end
+ end
+ if token == '//'
+ or token == '<<'
+ or token == '>>' then
+ if State.version ~= 'Lua 5.3'
+ and State.version ~= 'Lua 5.4' then
+ pushError {
+ type = 'UNSUPPORT_SYMBOL',
+ version = {'Lua 5.3', 'Lua 5.4'},
+ start = op.start,
+ finish = op.finish,
+ info = {
+ version = State.version,
+ }
+ }
+ end
+ end
+ Index = Index + 2
+ return op, myLevel
+end
+
+function parseExp(asAction, level)
+ local exp
+ local uop, uopLevel = parseUnaryOP()
+ if uop then
+ skipSpace()
+ local child = parseExp(asAction, uopLevel)
+ -- 预计算负数
+ if uop.type == '-'
+ and child
+ and (child.type == 'number' or child.type == 'integer') then
+ child.start = uop.start
+ child[1] = - child[1]
+ exp = child
+ else
+ exp = {
+ type = 'unary',
+ op = uop,
+ start = uop.start,
+ finish = child and child.finish or uop.finish,
+ [1] = child,
+ }
+ if child then
+ child.parent = exp
+ else
+ missExp()
+ end
+ end
+ else
+ exp = parseExpUnit()
+ if not exp then
+ return nil
+ end
+ end
+
+ while true do
+ skipSpace()
+ local bop, bopLevel = parseBinaryOP(asAction, level)
+ if not bop then
+ break
+ end
+
+ ::AGAIN::
+ skipSpace()
+ local isForward = SymbolForward[bopLevel]
+ local child = parseExp(asAction, isForward and (bopLevel + 0.5) or bopLevel)
+ if not child then
+ if skipUnknownSymbol() then
+ goto AGAIN
+ else
+ missExp()
+ end
+ end
+ local bin = {
+ type = 'binary',
+ start = exp.start,
+ finish = child and child.finish or bop.finish,
+ op = bop,
+ [1] = exp,
+ [2] = child
+ }
+ exp.parent = bin
+ if child then
+ child.parent = bin
+ end
+ exp = bin
+ end
+
+ return exp
+end
+
+local function skipSeps()
+ while true do
+ skipSpace()
+ if Tokens[Index + 1] == ',' then
+ missExp()
+ Index = Index + 2
+ else
+ break
+ end
+ end
+end
+
+---@return parser.object? first
+---@return parser.object? second
+---@return parser.object[]? rest
+local function parseSetValues()
+ skipSpace()
+ local first = parseExp()
+ if not first then
+ return nil
+ end
+ skipSpace()
+ if Tokens[Index + 1] ~= ',' then
+ return first
+ end
+ Index = Index + 2
+ skipSeps()
+ local second = parseExp()
+ if not second then
+ missExp()
+ return first
+ end
+ skipSpace()
+ if Tokens[Index + 1] ~= ',' then
+ return first, second
+ end
+ Index = Index + 2
+ skipSeps()
+ local third = parseExp()
+ if not third then
+ missExp()
+ return first, second
+ end
+
+ local rest = { third }
+ while true do
+ skipSpace()
+ if Tokens[Index + 1] ~= ',' then
+ return first, second, rest
+ end
+ Index = Index + 2
+ skipSeps()
+ local exp = parseExp()
+ if not exp then
+ missExp()
+ return first, second, rest
+ end
+ rest[#rest+1] = exp
+ end
+end
+
+local function pushActionIntoCurrentChunk(action)
+ local chunk = Chunk[#Chunk]
+ if chunk then
+ chunk[#chunk+1] = action
+ action.parent = chunk
+ end
+end
+
+---@return parser.object? second
+---@return parser.object[]? rest
+local function parseVarTails(parser, isLocal)
+ if Tokens[Index + 1] ~= ',' then
+ return nil
+ end
+ Index = Index + 2
+ skipSpace()
+ local second = parser(true)
+ if not second then
+ missName()
+ return nil
+ end
+ if isLocal then
+ createLocal(second, parseLocalAttrs())
+ end
+ skipSpace()
+ if Tokens[Index + 1] ~= ',' then
+ return second
+ end
+ Index = Index + 2
+ skipSeps()
+ local third = parser(true)
+ if not third then
+ missName()
+ return second
+ end
+ if isLocal then
+ createLocal(third, parseLocalAttrs())
+ end
+ local rest = { third }
+ while true do
+ skipSpace()
+ if Tokens[Index + 1] ~= ',' then
+ return second, rest
+ end
+ Index = Index + 2
+ skipSeps()
+ local name = parser(true)
+ if not name then
+ missName()
+ return second, rest
+ end
+ if isLocal then
+ createLocal(name, parseLocalAttrs())
+ end
+ rest[#rest+1] = name
+ end
+end
+
+local function bindValue(n, v, index, lastValue, isLocal, isSet)
+ if isLocal then
+ if v and v.special then
+ addSpecial(v.special, n)
+ end
+ elseif isSet then
+ n.type = GetToSetMap[n.type] or n.type
+ if n.type == 'setlocal' then
+ local loc = n.node
+ if loc.attrs then
+ pushError {
+ type = 'SET_CONST',
+ start = n.start,
+ finish = n.finish,
+ }
+ end
+ end
+ end
+ if not v and lastValue then
+ if lastValue.type == 'call'
+ or lastValue.type == 'varargs' then
+ v = lastValue
+ if not v.extParent then
+ v.extParent = {}
+ end
+ end
+ end
+ if v then
+ if v.type == 'call'
+ or v.type == 'varargs' then
+ local select = {
+ type = 'select',
+ sindex = index,
+ start = v.start,
+ finish = v.finish,
+ vararg = v
+ }
+ if v.parent then
+ v.extParent[#v.extParent+1] = select
+ else
+ v.parent = select
+ end
+ v = select
+ end
+ n.value = v
+ n.range = v.finish
+ v.parent = n
+ end
+end
+
+local function parseMultiVars(n1, parser, isLocal)
+ local n2, nrest = parseVarTails(parser, isLocal)
+ skipSpace()
+ local v1, v2, vrest
+ local isSet
+ local max = 1
+ if expectAssign(not isLocal) then
+ v1, v2, vrest = parseSetValues()
+ isSet = true
+ if not v1 then
+ missExp()
+ end
+ end
+ local index = 1
+ bindValue(n1, v1, index, nil, isLocal, isSet)
+ local lastValue = v1
+ local lastVar = n1
+ if n2 then
+ max = 2
+ if not v2 then
+ index = 2
+ end
+ bindValue(n2, v2, index, lastValue, isLocal, isSet)
+ lastValue = v2 or lastValue
+ lastVar = n2
+ pushActionIntoCurrentChunk(n2)
+ end
+ if nrest then
+ for i = 1, #nrest do
+ local n = nrest[i]
+ local v = vrest and vrest[i]
+ max = i + 2
+ if not v then
+ index = index + 1
+ end
+ bindValue(n, v, index, lastValue, isLocal, isSet)
+ lastValue = v or lastValue
+ lastVar = n
+ pushActionIntoCurrentChunk(n)
+ end
+ end
+
+ if isLocal then
+ local effect = lastValue and lastValue.finish or lastVar.finish
+ n1.effect = effect
+ if n2 then
+ n2.effect = effect
+ end
+ if nrest then
+ for i = 1, #nrest do
+ nrest[i].effect = effect
+ end
+ end
+ end
+
+ if v2 and not n2 then
+ v2.redundant = {
+ max = max,
+ passed = 2,
+ }
+ pushActionIntoCurrentChunk(v2)
+ end
+ if vrest then
+ for i = 1, #vrest do
+ local v = vrest[i]
+ if not nrest or not nrest[i] then
+ v.redundant = {
+ max = max,
+ passed = i + 2,
+ }
+ pushActionIntoCurrentChunk(v)
+ end
+ end
+ end
+
+ return n1, isSet
+end
+
+local function compileExpAsAction(exp)
+ pushActionIntoCurrentChunk(exp)
+ if GetToSetMap[exp.type] then
+ skipSpace()
+ local isLocal
+ if exp.type == 'getlocal' and exp[1] == State.ENVMode then
+ exp.special = nil
+ local loc = createLocal(exp, parseLocalAttrs())
+ loc.locPos = exp.start
+ loc.effect = maxinteger
+ isLocal = true
+ skipSpace()
+ end
+ local action, isSet = parseMultiVars(exp, parseExp, isLocal)
+ if isSet
+ or action.type == 'getmethod' then
+ return action
+ end
+ end
+
+ if exp.type == 'call' then
+ if exp.node.special == 'error' then
+ for i = #Chunk, 1, -1 do
+ local block = Chunk[i]
+ if block.type == 'ifblock'
+ or block.type == 'elseifblock'
+ or block.type == 'elseblock'
+ or block.type == 'function' then
+ block.hasError = true
+ break
+ end
+ end
+ end
+ return exp
+ end
+
+ if exp.type == 'binary' then
+ if GetToSetMap[exp[1].type] then
+ local op = exp.op
+ if op.type == '==' then
+ pushError {
+ type = 'ERR_ASSIGN_AS_EQ',
+ start = op.start,
+ finish = op.finish,
+ fix = {
+ title = 'FIX_ASSIGN_AS_EQ',
+ {
+ start = op.start,
+ finish = op.finish,
+ text = '=',
+ }
+ }
+ }
+ return
+ end
+ end
+ end
+
+ pushError {
+ type = 'EXP_IN_ACTION',
+ start = exp.start,
+ finish = exp.finish,
+ }
+
+ return exp
+end
+
+local function parseLocal()
+ local locPos = getPosition(Tokens[Index], 'left')
+ Index = Index + 2
+ skipSpace()
+ local word = peekWord()
+ if not word then
+ missName()
+ return nil
+ end
+
+ if word == 'function' then
+ local func = parseFunction(true, true)
+ local name = func.name
+ if name then
+ func.name = nil
+ name.value = func
+ name.vstart = func.start
+ name.range = func.finish
+ name.locPos = locPos
+ func.parent = name
+ pushActionIntoCurrentChunk(name)
+ return name
+ else
+ missName(func.keyword[2])
+ pushActionIntoCurrentChunk(func)
+ return func
+ end
+ end
+
+ local name = parseName(true)
+ if not name then
+ missName()
+ return nil
+ end
+ local loc = createLocal(name, parseLocalAttrs())
+ loc.locPos = locPos
+ loc.effect = maxinteger
+ pushActionIntoCurrentChunk(loc)
+ skipSpace()
+ parseMultiVars(loc, parseName, true)
+
+ return loc
+end
+
+local function parseDo()
+ local doLeft = getPosition(Tokens[Index], 'left')
+ local doRight = getPosition(Tokens[Index] + 1, 'right')
+ local obj = {
+ type = 'do',
+ start = doLeft,
+ finish = doRight,
+ keyword = {
+ [1] = doLeft,
+ [2] = doRight,
+ },
+ }
+ Index = Index + 2
+ pushActionIntoCurrentChunk(obj)
+ pushChunk(obj)
+ parseActions()
+ popChunk()
+ if Tokens[Index + 1] == 'end' then
+ obj.finish = getPosition(Tokens[Index] + 2, 'right')
+ obj.keyword[3] = getPosition(Tokens[Index], 'left')
+ obj.keyword[4] = getPosition(Tokens[Index] + 2, 'right')
+ Index = Index + 2
+ else
+ missEnd(doLeft, doRight)
+ end
+ if obj.locals then
+ LocalCount = LocalCount - #obj.locals
+ end
+
+ return obj
+end
+
+local function parseReturn()
+ local returnLeft = getPosition(Tokens[Index], 'left')
+ local returnRight = getPosition(Tokens[Index] + 5, 'right')
+ Index = Index + 2
+ skipSpace()
+ local rtn = parseExpList(true)
+ if rtn then
+ rtn.type = 'return'
+ rtn.start = returnLeft
+ else
+ rtn = {
+ type = 'return',
+ start = returnLeft,
+ finish = returnRight,
+ }
+ end
+ pushActionIntoCurrentChunk(rtn)
+ for i = #Chunk, 1, -1 do
+ local block = Chunk[i]
+ if block.type == 'function'
+ or block.type == 'main' then
+ if not block.returns then
+ block.returns = {}
+ end
+ block.returns[#block.returns+1] = rtn
+ break
+ end
+ end
+ for i = #Chunk, 1, -1 do
+ local block = Chunk[i]
+ if block.type == 'ifblock'
+ or block.type == 'elseifblock'
+ or block.type == 'elseblock'
+ or block.type == 'function' then
+ block.hasReturn = true
+ break
+ end
+ end
+
+ return rtn
+end
+
+local function parseLabel()
+ local left = getPosition(Tokens[Index], 'left')
+ Index = Index + 2
+ skipSpace()
+ local label = parseName()
+ skipSpace()
+
+ if not label then
+ missName()
+ end
+
+ if Tokens[Index + 1] == '::' then
+ Index = Index + 2
+ else
+ if label then
+ missSymbol '::'
+ end
+ end
+
+ if not label then
+ return nil
+ end
+
+ label.type = 'label'
+ pushActionIntoCurrentChunk(label)
+
+ local block = guide.getBlock(label)
+ if block then
+ if not block.labels then
+ block.labels = {}
+ end
+ local name = label[1]
+ local olabel = guide.getLabel(block, name)
+ if olabel then
+ if State.version == 'Lua 5.4'
+ or block == guide.getBlock(olabel) then
+ pushError {
+ type = 'REDEFINED_LABEL',
+ start = label.start,
+ finish = label.finish,
+ relative = {
+ {
+ olabel.start,
+ olabel.finish,
+ }
+ }
+ }
+ end
+ end
+ block.labels[name] = label
+ end
+
+ if State.version == 'Lua 5.1' then
+ pushError {
+ type = 'UNSUPPORT_SYMBOL',
+ start = left,
+ finish = lastRightPosition(),
+ version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
+ info = {
+ version = State.version,
+ }
+ }
+ return
+ end
+ return label
+end
+
+local function parseGoTo()
+ local start = getPosition(Tokens[Index], 'left')
+ Index = Index + 2
+ skipSpace()
+
+ local action = parseName()
+ if not action then
+ missName()
+ return nil
+ end
+
+ action.type = 'goto'
+ action.keyStart = start
+
+ for i = #Chunk, 1, -1 do
+ local chunk = Chunk[i]
+ if chunk.type == 'function'
+ or chunk.type == 'main' then
+ if not chunk.gotos then
+ chunk.gotos = {}
+ end
+ chunk.gotos[#chunk.gotos+1] = action
+ break
+ end
+ end
+ for i = #Chunk, 1, -1 do
+ local chunk = Chunk[i]
+ if chunk.type == 'ifblock'
+ or chunk.type == 'elseifblock'
+ or chunk.type == 'elseblock' then
+ chunk.hasGoTo = true
+ break
+ end
+ end
+
+ pushActionIntoCurrentChunk(action)
+ return action
+end
+
+local function parseIfBlock(parent)
+ local ifLeft = getPosition(Tokens[Index], 'left')
+ local ifRight = getPosition(Tokens[Index] + 1, 'right')
+ Index = Index + 2
+ local ifblock = {
+ type = 'ifblock',
+ parent = parent,
+ start = ifLeft,
+ finish = ifRight,
+ keyword = {
+ [1] = ifLeft,
+ [2] = ifRight,
+ }
+ }
+ skipSpace()
+ local filter = parseExp()
+ if filter then
+ ifblock.filter = filter
+ ifblock.finish = filter.finish
+ filter.parent = ifblock
+ else
+ missExp()
+ end
+ skipSpace()
+ local thenToken = Tokens[Index + 1]
+ if thenToken == 'then'
+ or thenToken == 'do' then
+ ifblock.finish = getPosition(Tokens[Index] + #thenToken - 1, 'right')
+ ifblock.keyword[3] = getPosition(Tokens[Index], 'left')
+ ifblock.keyword[4] = ifblock.finish
+ if thenToken == 'do' then
+ pushError {
+ type = 'ERR_THEN_AS_DO',
+ start = ifblock.keyword[3],
+ finish = ifblock.keyword[4],
+ fix = {
+ title = 'FIX_THEN_AS_DO',
+ {
+ start = ifblock.keyword[3],
+ finish = ifblock.keyword[4],
+ text = 'then',
+ }
+ }
+ }
+ end
+ Index = Index + 2
+ else
+ missSymbol 'then'
+ end
+ pushChunk(ifblock)
+ parseActions()
+ popChunk()
+ ifblock.finish = lastRightPosition()
+ if ifblock.locals then
+ LocalCount = LocalCount - #ifblock.locals
+ end
+ return ifblock
+end
+
+local function parseElseIfBlock(parent)
+ local ifLeft = getPosition(Tokens[Index], 'left')
+ local ifRight = getPosition(Tokens[Index] + 5, 'right')
+ local elseifblock = {
+ type = 'elseifblock',
+ parent = parent,
+ start = ifLeft,
+ finish = ifRight,
+ keyword = {
+ [1] = ifLeft,
+ [2] = ifRight,
+ }
+ }
+ Index = Index + 2
+ skipSpace()
+ local filter = parseExp()
+ if filter then
+ elseifblock.filter = filter
+ elseifblock.finish = filter.finish
+ filter.parent = elseifblock
+ else
+ missExp()
+ end
+ skipSpace()
+ local thenToken = Tokens[Index + 1]
+ if thenToken == 'then'
+ or thenToken == 'do' then
+ elseifblock.finish = getPosition(Tokens[Index] + #thenToken - 1, 'right')
+ elseifblock.keyword[3] = getPosition(Tokens[Index], 'left')
+ elseifblock.keyword[4] = elseifblock.finish
+ if thenToken == 'do' then
+ pushError {
+ type = 'ERR_THEN_AS_DO',
+ start = elseifblock.keyword[3],
+ finish = elseifblock.keyword[4],
+ fix = {
+ title = 'FIX_THEN_AS_DO',
+ {
+ start = elseifblock.keyword[3],
+ finish = elseifblock.keyword[4],
+ text = 'then',
+ }
+ }
+ }
+ end
+ Index = Index + 2
+ else
+ missSymbol 'then'
+ end
+ pushChunk(elseifblock)
+ parseActions()
+ popChunk()
+ elseifblock.finish = lastRightPosition()
+ if elseifblock.locals then
+ LocalCount = LocalCount - #elseifblock.locals
+ end
+ return elseifblock
+end
+
+local function parseElseBlock(parent)
+ local ifLeft = getPosition(Tokens[Index], 'left')
+ local ifRight = getPosition(Tokens[Index] + 3, 'right')
+ local elseblock = {
+ type = 'elseblock',
+ parent = parent,
+ start = ifLeft,
+ finish = ifRight,
+ keyword = {
+ [1] = ifLeft,
+ [2] = ifRight,
+ }
+ }
+ Index = Index + 2
+ skipSpace()
+ pushChunk(elseblock)
+ parseActions()
+ popChunk()
+ elseblock.finish = lastRightPosition()
+ if elseblock.locals then
+ LocalCount = LocalCount - #elseblock.locals
+ end
+ return elseblock
+end
+
+local function parseIf()
+ local token = Tokens[Index + 1]
+ local left = getPosition(Tokens[Index], 'left')
+ local action = {
+ type = 'if',
+ start = left,
+ finish = getPosition(Tokens[Index] + #token - 1, 'right'),
+ }
+ pushActionIntoCurrentChunk(action)
+ if token ~= 'if' then
+ missSymbol('if', left, left)
+ end
+ local hasElse
+ while true do
+ local word = Tokens[Index + 1]
+ local child
+ if word == 'if' then
+ child = parseIfBlock(action)
+ elseif word == 'elseif' then
+ child = parseElseIfBlock(action)
+ elseif word == 'else' then
+ child = parseElseBlock(action)
+ end
+ if not child then
+ break
+ end
+ if hasElse then
+ pushError {
+ type = 'BLOCK_AFTER_ELSE',
+ start = child.start,
+ finish = child.finish,
+ }
+ end
+ if word == 'else' then
+ hasElse = true
+ end
+ action[#action+1] = child
+ action.finish = child.finish
+ skipSpace()
+ end
+
+ if Tokens[Index + 1] == 'end' then
+ action.finish = getPosition(Tokens[Index] + 2, 'right')
+ Index = Index + 2
+ else
+ missEnd(action[1].keyword[1], action[1].keyword[2])
+ end
+
+ return action
+end
+
+local function parseFor()
+ local action = {
+ type = 'for',
+ start = getPosition(Tokens[Index], 'left'),
+ finish = getPosition(Tokens[Index] + 2, 'right'),
+ keyword = {},
+ }
+ action.keyword[1] = action.start
+ action.keyword[2] = action.finish
+ Index = Index + 2
+ pushActionIntoCurrentChunk(action)
+ pushChunk(action)
+ skipSpace()
+ local nameOrList = parseNameOrList(action)
+ if not nameOrList then
+ missName()
+ end
+ skipSpace()
+ -- for i =
+ if expectAssign() then
+ action.type = 'loop'
+
+ skipSpace()
+ local expList = parseExpList()
+ local name
+ if nameOrList then
+ if nameOrList.type == 'name' then
+ name = nameOrList
+ else
+ name = nameOrList[1]
+ end
+ end
+ if name then
+ local loc = createLocal(name)
+ loc.parent = action
+ action.finish = name.finish
+ action.loc = loc
+ end
+ if expList then
+ expList.parent = action
+ local value = expList[1]
+ if value then
+ value.parent = expList
+ action.init = value
+ action.finish = expList[#expList].finish
+ end
+ local max = expList[2]
+ if max then
+ max.parent = expList
+ action.max = max
+ action.finish = max.finish
+ else
+ pushError {
+ type = 'MISS_LOOP_MAX',
+ start = lastRightPosition(),
+ finish = lastRightPosition(),
+ }
+ end
+ local step = expList[3]
+ if step then
+ step.parent = expList
+ action.step = step
+ action.finish = step.finish
+ end
+ else
+ pushError {
+ type = 'MISS_LOOP_MIN',
+ start = lastRightPosition(),
+ finish = lastRightPosition(),
+ }
+ end
+
+ if action.loc then
+ action.loc.effect = action.finish
+ end
+ elseif Tokens[Index + 1] == 'in' then
+ action.type = 'in'
+ local inLeft = getPosition(Tokens[Index], 'left')
+ local inRight = getPosition(Tokens[Index] + 1, 'right')
+ Index = Index + 2
+ skipSpace()
+
+ local exps = parseExpList()
+
+ action.finish = inRight
+ action.keyword[3] = inLeft
+ action.keyword[4] = inRight
+
+ local list
+ if nameOrList and nameOrList.type == 'name' then
+ list = {
+ type = 'list',
+ start = nameOrList.start,
+ finish = nameOrList.finish,
+ parent = action,
+ [1] = nameOrList,
+ }
+ else
+ list = nameOrList
+ end
+
+ if exps then
+ local lastExp = exps[#exps]
+ if lastExp then
+ action.finish = lastExp.finish
+ end
+
+ action.exps = exps
+ exps.parent = action
+ for i = 1, #exps do
+ local exp = exps[i]
+ exp.parent = exps
+ end
+ else
+ missExp()
+ end
+
+ if list then
+ local lastName = list[#list]
+ list.range = lastName and lastName.range or inRight
+ action.keys = list
+ for i = 1, #list do
+ local loc = createLocal(list[i])
+ loc.parent = action
+ loc.effect = action.finish
+ end
+ end
+ else
+ missSymbol 'in'
+ end
+
+ skipSpace()
+ local doToken = Tokens[Index + 1]
+ if doToken == 'do'
+ or doToken == 'then' then
+ local left = getPosition(Tokens[Index], 'left')
+ local right = getPosition(Tokens[Index] + #doToken - 1, 'right')
+ action.finish = left
+ action.keyword[#action.keyword+1] = left
+ action.keyword[#action.keyword+1] = right
+ if doToken == 'then' then
+ pushError {
+ type = 'ERR_DO_AS_THEN',
+ start = left,
+ finish = right,
+ fix = {
+ title = 'FIX_DO_AS_THEN',
+ {
+ start = left,
+ finish = right,
+ text = 'do',
+ }
+ }
+ }
+ end
+ Index = Index + 2
+ else
+ missSymbol 'do'
+ end
+
+ skipSpace()
+ parseActions()
+ popChunk()
+
+ skipSpace()
+ if Tokens[Index + 1] == 'end' then
+ action.finish = getPosition(Tokens[Index] + 2, 'right')
+ action.keyword[#action.keyword+1] = getPosition(Tokens[Index], 'left')
+ action.keyword[#action.keyword+1] = action.finish
+ Index = Index + 2
+ else
+ missEnd(action.keyword[1], action.keyword[2])
+ end
+
+ if action.locals then
+ LocalCount = LocalCount - #action.locals
+ end
+
+ return action
+end
+
+local function parseWhile()
+ local action = {
+ type = 'while',
+ start = getPosition(Tokens[Index], 'left'),
+ finish = getPosition(Tokens[Index] + 4, 'right'),
+ keyword = {},
+ }
+ action.keyword[1] = action.start
+ action.keyword[2] = action.finish
+ Index = Index + 2
+
+ skipSpace()
+ local nextToken = Tokens[Index + 1]
+ local filter = nextToken ~= 'do'
+ and nextToken ~= 'then'
+ and parseExp()
+ if filter then
+ action.filter = filter
+ action.finish = filter.finish
+ filter.parent = action
+ else
+ missExp()
+ end
+
+ skipSpace()
+ local doToken = Tokens[Index + 1]
+ if doToken == 'do'
+ or doToken == 'then' then
+ local left = getPosition(Tokens[Index], 'left')
+ local right = getPosition(Tokens[Index] + #doToken - 1, 'right')
+ action.finish = left
+ action.keyword[#action.keyword+1] = left
+ action.keyword[#action.keyword+1] = right
+ if doToken == 'then' then
+ pushError {
+ type = 'ERR_DO_AS_THEN',
+ start = left,
+ finish = right,
+ fix = {
+ title = 'FIX_DO_AS_THEN',
+ {
+ start = left,
+ finish = right,
+ text = 'do',
+ }
+ }
+ }
+ end
+ Index = Index + 2
+ else
+ missSymbol 'do'
+ end
+
+ pushActionIntoCurrentChunk(action)
+ pushChunk(action)
+ skipSpace()
+ parseActions()
+ popChunk()
+
+ skipSpace()
+ if Tokens[Index + 1] == 'end' then
+ action.finish = getPosition(Tokens[Index] + 2, 'right')
+ action.keyword[#action.keyword+1] = getPosition(Tokens[Index], 'left')
+ action.keyword[#action.keyword+1] = action.finish
+ Index = Index + 2
+ else
+ missEnd(action.keyword[1], action.keyword[2])
+ end
+
+ if action.locals then
+ LocalCount = LocalCount - #action.locals
+ end
+
+ return action
+end
+
+local function parseRepeat()
+ local action = {
+ type = 'repeat',
+ start = getPosition(Tokens[Index], 'left'),
+ finish = getPosition(Tokens[Index] + 5, 'right'),
+ keyword = {},
+ }
+ action.keyword[1] = action.start
+ action.keyword[2] = action.finish
+ Index = Index + 2
+
+ pushActionIntoCurrentChunk(action)
+ pushChunk(action)
+ skipSpace()
+ parseActions()
+
+ skipSpace()
+ if Tokens[Index + 1] == 'until' then
+ action.finish = getPosition(Tokens[Index] + 4, 'right')
+ action.keyword[#action.keyword+1] = getPosition(Tokens[Index], 'left')
+ action.keyword[#action.keyword+1] = action.finish
+ Index = Index + 2
+
+ skipSpace()
+ local filter = parseExp()
+ if filter then
+ action.filter = filter
+ filter.parent = action
+ else
+ missExp()
+ end
+
+ else
+ missSymbol 'until'
+ end
+
+ popChunk()
+ if action.filter then
+ action.finish = action.filter.finish
+ end
+
+ if action.locals then
+ LocalCount = LocalCount - #action.locals
+ end
+
+ return action
+end
+
+local function parseBreak()
+ local returnLeft = getPosition(Tokens[Index], 'left')
+ local returnRight = getPosition(Tokens[Index] + #Tokens[Index + 1] - 1, 'right')
+ Index = Index + 2
+ skipSpace()
+ local action = {
+ type = 'break',
+ start = returnLeft,
+ finish = returnRight,
+ }
+
+ local ok
+ for i = #Chunk, 1, -1 do
+ local chunk = Chunk[i]
+ if chunk.type == 'function' then
+ break
+ end
+ if chunk.type == 'while'
+ or chunk.type == 'in'
+ or chunk.type == 'loop'
+ or chunk.type == 'repeat'
+ or chunk.type == 'for' then
+ if not chunk.breaks then
+ chunk.breaks = {}
+ end
+ chunk.breaks[#chunk.breaks+1] = action
+ ok = true
+ break
+ end
+ end
+ for i = #Chunk, 1, -1 do
+ local chunk = Chunk[i]
+ if chunk.type == 'ifblock'
+ or chunk.type == 'elseifblock'
+ or chunk.type == 'elseblock' then
+ chunk.hasBreak = true
+ break
+ end
+ end
+ if not ok and Mode == 'Lua' then
+ pushError {
+ type = 'BREAK_OUTSIDE',
+ start = action.start,
+ finish = action.finish,
+ }
+ end
+
+ pushActionIntoCurrentChunk(action)
+ return action
+end
+
+function parseAction()
+ local token = Tokens[Index + 1]
+
+ if token == '::' then
+ return parseLabel()
+ end
+
+ if token == 'local' then
+ return parseLocal()
+ end
+
+ if token == 'if'
+ or token == 'elseif'
+ or token == 'else' then
+ return parseIf()
+ end
+
+ if token == 'for' then
+ return parseFor()
+ end
+
+ if token == 'do' then
+ return parseDo()
+ end
+
+ if token == 'return' then
+ return parseReturn()
+ end
+
+ if token == 'break' then
+ return parseBreak()
+ end
+
+ if token == 'continue' and State.options.nonstandardSymbol['continue'] then
+ return parseBreak()
+ end
+
+ if token == 'while' then
+ return parseWhile()
+ end
+
+ if token == 'repeat' then
+ return parseRepeat()
+ end
+
+ if token == 'goto' and isKeyWord 'goto' then
+ return parseGoTo()
+ end
+
+ if token == 'function' then
+ local exp = parseFunction(false, true)
+ local name = exp.name
+ if name then
+ exp.name = nil
+ name.type = GetToSetMap[name.type]
+ name.value = exp
+ name.vstart = exp.start
+ name.range = exp.finish
+ exp.parent = name
+ if name.type == 'setlocal' then
+ local loc = name.node
+ if loc.attrs then
+ pushError {
+ type = 'SET_CONST',
+ start = name.start,
+ finish = name.finish,
+ }
+ end
+ end
+ pushActionIntoCurrentChunk(name)
+ return name
+ else
+ pushActionIntoCurrentChunk(exp)
+ missName(exp.keyword[2])
+ return exp
+ end
+ end
+
+ local exp = parseExp(true)
+ if exp then
+ local action = compileExpAsAction(exp)
+ if action then
+ return action
+ end
+ end
+ return nil, true
+end
+
+local function skipFirstComment()
+ if Tokens[Index + 1] ~= '#' then
+ return
+ end
+ while true do
+ Index = Index + 2
+ local token = Tokens[Index + 1]
+ if not token then
+ break
+ end
+ if NLMap[token] then
+ skipNL()
+ break
+ end
+ end
+end
+
+local function parseLua()
+ local main = {
+ type = 'main',
+ start = 0,
+ finish = 0,
+ }
+ pushChunk(main)
+ createLocal{
+ type = 'local',
+ start = -1,
+ finish = -1,
+ effect = -1,
+ parent = main,
+ tag = '_ENV',
+ special= '_G',
+ [1] = State.ENVMode,
+ }
+ LocalCount = 0
+ skipFirstComment()
+ while true do
+ parseActions()
+ if Index <= #Tokens then
+ unknownSymbol()
+ Index = Index + 2
+ else
+ break
+ end
+ end
+ popChunk()
+ main.finish = getPosition(#Lua, 'right')
+
+ return main
+end
+
+local function initState(lua, version, options)
+ Lua = lua
+ Line = 0
+ LineOffset = 1
+ LastTokenFinish = 0
+ LocalCount = 0
+ Chunk = {}
+ Tokens = tokens(lua)
+ Index = 1
+ ---@class parser.state
+ ---@field uri uri
+ local state = {
+ version = version,
+ lua = lua,
+ ast = {},
+ errs = {},
+ comms = {},
+ lines = {
+ [0] = 1,
+ },
+ options = options or {},
+ }
+ if not state.options.nonstandardSymbol then
+ state.options.nonstandardSymbol = {}
+ end
+ State = state
+ if version == 'Lua 5.1' or version == 'LuaJIT' then
+ state.ENVMode = '@fenv'
+ else
+ state.ENVMode = '_ENV'
+ end
+
+ pushError = function (err)
+ local errs = state.errs
+ if err.finish < err.start then
+ err.finish = err.start
+ end
+ local last = errs[#errs]
+ if last then
+ if last.start <= err.start and last.finish >= err.finish then
+ return
+ end
+ end
+ err.level = err.level or 'Error'
+ errs[#errs+1] = err
+ return err
+ end
+
+ state.pushError = pushError
+end
+
+return function (lua, mode, version, options)
+ Mode = mode
+ initState(lua, version, options)
+ skipSpace()
+ if mode == 'Lua' then
+ State.ast = parseLua()
+ elseif mode == 'Nil' then
+ State.ast = parseNil()
+ elseif mode == 'Boolean' then
+ State.ast = parseBoolean()
+ elseif mode == 'String' then
+ State.ast = parseString()
+ elseif mode == 'Number' then
+ State.ast = parseNumber()
+ elseif mode == 'Name' then
+ State.ast = parseName()
+ elseif mode == 'Exp' then
+ State.ast = parseExp()
+ elseif mode == 'Action' then
+ State.ast = parseAction()
+ end
+
+ if State.ast then
+ State.ast.state = State
+ end
+
+ while true do
+ if Index <= #Tokens then
+ unknownSymbol()
+ Index = Index + 2
+ else
+ break
+ end
+ end
+
+ return State
end
diff --git a/script/parser/grammar.lua b/script/parser/grammar.lua
deleted file mode 100644
index a28b7950..00000000
--- a/script/parser/grammar.lua
+++ /dev/null
@@ -1,573 +0,0 @@
-local re = require 'parser.relabel'
-local m = require 'lpeglabel'
-local ast = require 'parser.ast'
-
-local scriptBuf = ''
-local compiled = {}
-local defs = ast.defs
-
--- goto 可以作为名字,合法性之后处理
-local RESERVED = {
- ['and'] = true,
- ['break'] = true,
- ['do'] = true,
- ['else'] = true,
- ['elseif'] = true,
- ['end'] = true,
- ['false'] = true,
- ['for'] = true,
- ['function'] = true,
- ['if'] = true,
- ['in'] = true,
- ['local'] = true,
- ['nil'] = true,
- ['not'] = true,
- ['or'] = true,
- ['repeat'] = true,
- ['return'] = true,
- ['then'] = true,
- ['true'] = true,
- ['until'] = true,
- ['while'] = true,
-}
-
-defs.nl = (m.P'\r\n' + m.S'\r\n')
-defs.s = m.S' \t'
-defs.S = - defs.s
-defs.ea = '\a'
-defs.eb = '\b'
-defs.ef = '\f'
-defs.en = '\n'
-defs.er = '\r'
-defs.et = '\t'
-defs.ev = '\v'
-defs['nil'] = m.Cp() / function () return nil end
-defs['false'] = m.Cp() / function () return false end
-
-defs.NotReserved = function (_, _, str)
- if RESERVED[str] then
- return false
- end
- return true
-end
-defs.Reserved = function (_, _, str)
- if RESERVED[str] then
- return true
- end
- return false
-end
-defs.None = function () end
-defs.np = m.Cp() / function (n) return n+1 end
-defs.NameBody = m.R('az', 'AZ', '__', '\x80\xff') * m.R('09', 'az', 'AZ', '__', '\x80\xff')^0
-defs.NoNil = function (o)
- if o == nil then
- return
- end
- return o
-end
-
-m.setmaxstack(1000)
-
-local eof = re.compile '!. / %{SYNTAX_ERROR}'
-
-local function grammar(tag)
- return function (script)
- scriptBuf = script .. '\r\n' .. scriptBuf
- compiled[tag] = re.compile(scriptBuf, defs) * eof
- end
-end
-
-local function errorpos(pos, err)
- return {
- type = 'UNKNOWN',
- start = pos or 0,
- finish = pos or 0,
- err = err,
- }
-end
-
-grammar 'Comment' [[
-Comment <- LongComment
- / '--' ShortComment
-LongComment <- ({} '--[' {} {:eq: '='* :} {} '[' %nl?
- {(!CommentClose .)*}
- ((CommentClose / %nil) {}))
- -> LongComment
- / (
- {} '/*' {} %nl?
- {(!'*/' .)*}
- {} '*/' {}
- )
- -> CLongComment
-CommentClose <- {']' =eq ']'}
-ShortComment <- ({} {(!%nl .)*} {})
- -> ShortComment
-]]
-
-grammar 'Sp' [[
-Sp <- (Comment / %nl / %s)*
-Sps <- (Comment / %nl / %s)+
-]]
-
-grammar 'Common' [[
-Word <- [a-zA-Z0-9_]
-Cut <- !Word
-X16 <- [a-fA-F0-9]
-Rest <- (!%nl .)*
-
-AND <- Sp {'and'} Cut
-BREAK <- Sp 'break' Cut
-FALSE <- Sp 'false' Cut
-GOTO <- Sp 'goto' Cut
-LOCAL <- Sp 'local' Cut
-NIL <- Sp 'nil' Cut
-NOT <- Sp 'not' Cut
-OR <- Sp {'or'} Cut
-RETURN <- Sp 'return' Cut
-TRUE <- Sp 'true' Cut
-CONTINUE <- Sp 'continue' Cut
-
-DO <- Sp {} 'do' {} Cut
- / Sp({} 'then' {} Cut) -> ErrDo
-IF <- Sp {} 'if' {} Cut
-ELSE <- Sp {} 'else' {} Cut
-ELSEIF <- Sp {} 'elseif' {} Cut
-END <- Sp {} 'end' {} Cut
-FOR <- Sp {} 'for' {} Cut
-FUNCTION <- Sp {} 'function' {} Cut
-IN <- Sp {} 'in' {} Cut
-REPEAT <- Sp {} 'repeat' {} Cut
-THEN <- Sp {} 'then' {} Cut
- / Sp({} 'do' {} Cut) -> ErrThen
-UNTIL <- Sp {} 'until' {} Cut
-WHILE <- Sp {} 'while' {} Cut
-
-
-Esc <- '\' -> ''
- EChar
-EChar <- 'a' -> ea
- / 'b' -> eb
- / 'f' -> ef
- / 'n' -> en
- / 'r' -> er
- / 't' -> et
- / 'v' -> ev
- / '\'
- / '"'
- / "'"
- / %nl
- / ('z' (%nl / %s)*) -> ''
- / ({} 'x' {X16 X16}) -> Char16
- / ([0-9] [0-9]? [0-9]?) -> Char10
- / ('u{' {} {Word*} '}') -> CharUtf8
- -- 错误处理
- / 'x' {} -> MissEscX
- / 'u' !'{' {} -> MissTL
- / 'u{' Word* !'}' {} -> MissTR
- / {} -> ErrEsc
-
-BOR <- Sp {'|'}
-BXOR <- Sp {'~'} !'='
-BAND <- Sp {'&'}
-Bshift <- Sp {BshiftList}
-BshiftList <- '<<'
- / '>>'
-Concat <- Sp {'..'}
-Adds <- Sp {AddsList}
-AddsList <- '+'
- / '-'
-Muls <- Sp {MulsList}
-MulsList <- '*'
- / '//'
- / '/'
- / '%'
-Unary <- Sp {} {UnaryList}
-UnaryList <- NOT
- / '#'
- / '-'
- / '~' !'='
-POWER <- Sp {'^'}
-
-BinaryOp <-( Sp {} {'or' / '||'} Cut
- / Sp {} {'and' / '&&'} Cut
- / Sp {} {'<=' / '>=' / '<'!'<' / '>'!'>' / '~=' / '==' / '!='}
- / Sp {} ({} '=' {}) -> ErrEQ
- / Sp {} ({} '!=' {}) -> ErrUEQ
- / Sp {} {'|'}
- / Sp {} {'~'}
- / Sp {} {'&'}
- / Sp {} {'<<' / '>>'}
- / Sp {} {'..'} !'.'
- / Sp {} {'+' / '-'}
- / Sp {} {'*' / '//' / '/' / '%'}
- / Sp {} {'^'}
- )-> BinaryOp
-UnaryOp <-( Sp {} {'not' Cut / '#' / '~' !'=' / '-' !'-' / '!' !'='}
- )-> UnaryOp
-
-PL <- Sp '('
-PR <- Sp ')'
-BL <- Sp '[' !'[' !'='
-BR <- Sp ']'
-TL <- Sp '{'
-TR <- Sp '}'
-COMMA <- Sp ({} ',')
- -> COMMA
-SEMICOLON <- Sp ({} ';')
- -> SEMICOLON
-DOTS <- Sp ({} '...')
- -> DOTS
-DOT <- Sp ({} '.' !'.')
- -> DOT
-COLON <- Sp ({} ':' !':')
- -> COLON
-LABEL <- Sp '::'
-ASSIGN <- Sp '=' !'='
- / Sp ({} {'+=' / '-=' / '*=' / '\='})
- -> ASSIGN
-AssignOrEQ <- Sp ({} '==' {})
- -> ErrAssign
- / ASSIGN
-
-DirtyBR <- BR / {} -> MissBR
-DirtyTR <- TR / {} -> MissTR
-DirtyPR <- PR / {} -> MissPR
-DirtyLabel <- LABEL / {} -> MissLabel
-NeedEnd <- END / {} -> MissEnd
-NeedDo <- DO / {} -> MissDo
-NeedAssign <- ASSIGN / {} -> MissAssign
-NeedComma <- COMMA / {} -> MissComma
-NeedIn <- IN / {} -> MissIn
-NeedUntil <- UNTIL / {} -> MissUntil
-NeedThen <- THEN / {} -> MissThen
-]]
-
-grammar 'Nil' [[
-Nil <- Sp ({} -> Nil) NIL
-]]
-
-grammar 'Boolean' [[
-Boolean <- Sp ({} -> True) TRUE
- / Sp ({} -> False) FALSE
-]]
-
-grammar 'String' [[
-String <- Sp ({} StringDef {})
- -> String
-StringDef <- {'"'}
- {~(Esc / !%nl !'"' .)*~} -> 1
- ('"' / {} -> MissQuote1)
- / {"'"}
- {~(Esc / !%nl !"'" .)*~} -> 1
- ("'" / {} -> MissQuote2)
- / {'`'}
- {(!%nl !'`' .)*} -> 1
- ('`' / {} -> MissQuote3)
- / ('[' {} {:eq: '='* :} {} '[' %nl?
- {(!StringClose .)*} -> 1
- (StringClose / {}))
- -> LongString
-StringClose <- ']' =eq ']'
-]]
-
-grammar 'Number' [[
-Number <- Sp ({} {~ '-'? NumberDef ~} {}) -> Number
- NumberSuffix?
- ErrNumber?
-NumberDef <- Number16 / Integer2 / Number10
-NumberSuffix<- ({} {[uU]? [lL] [lL]}) -> FFINumber
- / ({} {[iI]}) -> ImaginaryNumber
-ErrNumber <- ({} {([0-9a-zA-Z] / '.')+}) -> UnknownSymbol
-
-Number10 <- Float10 Float10Exp?
- / Integer10 Float10? Float10Exp?
-Integer10 <- [0-9]+ ('.' [0-9]*)?
-Float10 <- '.' [0-9]+
-Float10Exp <- [eE] [+-]? [0-9]+
- / ({} [eE] [+-]? {}) -> MissExponent
-
-Number16 <- '0' [xX] Float16 Float16Exp?
- / '0' [xX] Integer16 Float16? Float16Exp?
-Integer16 <- X16+ ('.' X16*)?
- / ({} {Word*}) -> MustX16
-Float16 <- '.' X16+
- / '.' ({} {Word*}) -> MustX16
-Float16Exp <- [pP] [+-]? [0-9]+
- / ({} [pP] [+-]? {}) -> MissExponent
-
-Integer2 <- ({} '0' [bB] {[01]+})
- -> Integer2
-]]
-
-grammar 'Name' [[
-Name <- Sp ({} NameBody {})
- -> Name
-NameBody <- {%NameBody}
-KeyWord <- Sp NameBody=>Reserved
-MustName <- Name / DirtyName
-DirtyName <- {} -> DirtyName
-]]
-
-grammar 'DocType' [[
-DocType <- (!%nl !')' !',' DocChar)+
-DocChar <- '(' (!%nl !')' .)+ ')'?
- / '<' (!%nl !'>' .)+ '>'?
- / .
-]]
-
-grammar 'Exp' [[
-Exp <- (UnUnit BinUnit*)
- -> Binary
-BinUnit <- (BinaryOp UnUnit?)
- -> SubBinary
-UnUnit <- Number
- / (UnaryOp+ (ExpUnit / MissExp))
- -> Unary
- / ExpUnit
-ExpUnit <- Nil
- / Boolean
- / String
- / Number
- / Dots
- / Table
- / ExpFunction
- / Simple
-
-Simple <- {| Prefix (Sp Suffix)* |}
- -> Simple
-Prefix <- Sp ({} PL DirtyExp DirtyPR {})
- -> Paren
- / Single
-Single <- !FUNCTION Name
- -> Single
-Suffix <- SuffixWithoutCall
- / ({} PL SuffixCall DirtyPR {})
- -> Call
-SuffixCall <- Sp ({} {| (COMMA / CallArg)+ |} {})
- -> PackExpList
- / %nil
-CallArg <- Sp (Name {} {'?'? ':'} Sps DocType)
- -> CallArgSnip
- / Exp->NoNil
-SuffixWithoutCall
- <- (DOT (Name / MissField))
- -> GetField
- / ({} BL DirtyExp DirtyBR {})
- -> GetIndex
- / (COLON (Name / MissMethod) NeedCall)
- -> GetMethod
- / ({} {| Table |} {})
- -> Call
- / ({} {| String |} {})
- -> Call
-NeedCall <- (!(Sp CallStart) {} -> MissPL)?
-MissField <- {} -> MissField
-MissMethod <- {} -> MissMethod
-CallStart <- PL
- / TL
- / '"'
- / "'"
- / '[' '='* '['
-
-DirtyExp <- !THEN !DO !END Exp
- / {} -> DirtyExp
-MaybeExp <- Exp / MissExp
-MissExp <- {} -> MissExp
-ExpList <- Sp {| MaybeExp (Sp ',' MaybeExp)* |}
-
-Dots <- DOTS
- -> VarArgs
-
-Table <- Sp ({} TL {| TableField* |} DirtyTR {})
- -> Table
-TableField <- COMMA
- / SEMICOLON
- / Dots
- / NewIndex
- / NewField
- / TableExp
-Index <- BL DirtyExp DirtyBR
-NewIndex <- Sp ({} Index NeedAssign DirtyExp {})
- -> NewIndex
-NewField <- Sp ({} MustName ASSIGN DirtyExp {})
- -> NewField
-TableExp <- Sp ({} Exp {})
- -> TableExp
-
-ExpFunction <- Function
- -> ExpFunction
-Function <- FunctionBody
- -> Function
-FunctionBody
- <- FUNCTION FuncName FuncArgs
- {| (!END Action)* |}
- NeedEnd
- / FUNCTION FuncName FuncArgsMiss
- {| %nil |}
- NeedEnd
-FuncName <- !END {| Single (Sp SuffixWithoutCall)* |}
- -> Simple
- / %nil
-
-FuncArgs <- Sp ({} PL {| FuncArg+ |} DirtyPR {})
- -> FuncArgs
- / PL DirtyPR %nil
-FuncArgsMiss<- {} -> MissPL DirtyPR %nil
-FuncArg <- DOTS
- / Name
- / COMMA
-
--- 纯占位,修改了 `relabel.lua` 使重复定义不抛错
-Action <- !END .
-]]
-
-grammar 'Action' [[
-Action <- Sp (CrtAction / UnkAction)
-CrtAction <- Semicolon
- / Do
- / Break
- / Return
- / Label
- / GoTo
- / If
- / For
- / While
- / Repeat
- / NamedFunction
- / LocalFunction
- / Local
- / Set
- / Continue
- / Call
- / ExpInAction
-UnkAction <- ({} {Word+})
- -> UnknownAction
- / ({} '//' {} (LongComment / ShortComment) {})
- -> CCommentPrefix
- / ({} {. (!Sps !CrtAction .)*})
- -> UnknownAction
-ExpInAction <- Sp ({} Exp {})
- -> ExpInAction
-
-Semicolon <- Sp ';'
-SimpleList <- {| Simple (Sp ',' Simple)* |}
-
-Do <- Sp ({}
- 'do' Cut
- {| (!END Action)* |}
- NeedEnd)
- -> Do
-
-Break <- Sp ({} BREAK {})
- -> Break
-
-Continue <- Sp ({} CONTINUE {})
- => RTContinue
- -> Continue
-
-Return <- Sp ({} RETURN ReturnExpList {})
- -> Return
-ReturnExpList
- <- Sp !END !ELSEIF !ELSE {| Exp (Sp ',' MaybeExp)* |}
- / Sp {| %nil |}
-
-Label <- Sp ({} LABEL MustName DirtyLabel {})
- -> Label
-
-GoTo <- Sp ({} GOTO MustName {})
- -> GoTo
-
-If <- Sp ({} {| IfHead IfBody* |} NeedEnd)
- -> If
-
-IfHead <- Sp (IfPart {}) -> IfBlock
- / Sp (ElseIfPart {}) -> ElseIfBlock
- / Sp (ElsePart {}) -> ElseBlock
-IfBody <- Sp (ElseIfPart {}) -> ElseIfBlock
- / Sp (ElsePart {}) -> ElseBlock
-IfPart <- IF DirtyExp NeedThen
- {| (!ELSEIF !ELSE !END Action)* |}
-ElseIfPart <- ELSEIF DirtyExp NeedThen
- {| (!ELSEIF !ELSE !END Action)* |}
-ElsePart <- ELSE
- {| (!ELSEIF !ELSE !END Action)* |}
-
-For <- Loop / In
-
-Loop <- LoopBody
- -> Loop
-LoopBody <- FOR LoopArgs NeedDo
- {} {| (!END Action)* |}
- NeedEnd
-LoopArgs <- MustName AssignOrEQ
- ({} {| (COMMA / !DO !END Exp->NoNil)* |} {})
- -> PackLoopArgs
-
-In <- InBody
- -> In
-InBody <- FOR InNameList NeedIn InExpList NeedDo
- {} {| (!END Action)* |}
- NeedEnd
-InNameList <- ({} {| (COMMA / !IN !DO !END Name->NoNil)* |} {})
- -> PackInNameList
-InExpList <- ({} {| (COMMA / !DO !DO !END Exp->NoNil)* |} {})
- -> PackInExpList
-
-While <- WhileBody
- -> While
-WhileBody <- WHILE DirtyExp NeedDo
- {| (!END Action)* |}
- NeedEnd
-
-Repeat <- (RepeatBody {})
- -> Repeat
-RepeatBody <- REPEAT
- {| (!UNTIL Action)* |}
- NeedUntil DirtyExp
-
-LocalAttr <- {| (Sp '<' Sp MustName Sp LocalAttrEnd)+ |}
- -> LocalAttr
-LocalAttrEnd<- ({} '>' &'=') -> MissSpaceBetween
- / '>'
- / {} -> MissGT
-Local <- Sp ({} LOCAL LocalNameList ((AssignOrEQ ExpList) / %nil) {})
- -> Local
-Set <- Sp ({} SimpleList AssignOrEQ {} ExpList {})
- -> Set
-LocalNameList
- <- {| LocalName (Sp ',' LocalName)* |}
-LocalName <- (MustName LocalAttr?)
- -> LocalName
-
-NamedFunction
- <- Function
- -> NamedFunction
-
-Call <- Simple
- -> SimpleCall
-
-LocalFunction
- <- Sp ({} LOCAL Function)
- -> LocalFunction
-]]
-
-grammar 'Lua' [[
-Lua <- Head?
- ({} {| Action* |} {}) -> Lua
- Sp
-Head <- '#' (!%nl .)*
-]]
-
-return function (lua, mode)
- local gram = compiled[mode] or compiled['Lua']
- local r, _, pos = gram:match(lua)
- if not r then
- local err = errorpos(pos)
- return nil, err
- end
- if type(r) ~= 'table' then
- return nil
- end
-
- return r
-end
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index 06169b09..e4faf47f 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -4,15 +4,18 @@ local type = type
---@class parser.object
---@field bindDocs parser.object[]
---@field bindGroup parser.object[]
----@field bindSources parser.object[]
+---@field bindSource parser.object
---@field value parser.object
---@field parent parser.object
---@field type string
---@field special string
---@field tag string
----@field args parser.object[]
+---@field args { [integer]: parser.object, start: integer, finish: integer }
---@field locals parser.object[]
----@field returns parser.object[]
+---@field returns? parser.object[]
+---@field breaks? parser.object[]
+---@field exps parser.object[]
+---@field keys parser.object
---@field uri uri
---@field start integer
---@field finish integer
@@ -42,6 +45,7 @@ local type = type
---@field exp parser.object
---@field alias parser.object
---@field class parser.object
+---@field enum parser.object
---@field vararg parser.object
---@field param parser.object
---@field overload parser.object
@@ -49,6 +53,8 @@ local type = type
---@field upvalues table<string, string[]>
---@field ref parser.object[]
---@field returnIndex integer
+---@field assignIndex integer
+---@field docIndex integer
---@field docs parser.object[]
---@field state table
---@field comment table
@@ -65,6 +71,8 @@ local type = type
---@field hasGoTo? true
---@field hasReturn? true
---@field hasBreak? true
+---@field hasError? true
+---@field [integer] parser.object|any
---@field _root parser.object
---@class guide
@@ -134,6 +142,7 @@ local childMap = {
['doc.class'] = {'class', '#extends', '#signs', 'comment'},
['doc.type'] = {'#types', 'name', 'comment'},
['doc.alias'] = {'alias', 'extends', 'comment'},
+ ['doc.enum'] = {'enum', 'extends', 'comment'},
['doc.param'] = {'param', 'extends', 'comment'},
['doc.return'] = {'#returns', 'comment'},
['doc.field'] = {'field', 'extends', 'comment'},
@@ -154,6 +163,7 @@ local childMap = {
['doc.as'] = {'as'},
['doc.cast'] = {'loc', '#casts'},
['doc.cast.block'] = {'extends'},
+ ['doc.operator'] = {'op', 'exp', 'extends'}
}
---@type table<string, fun(obj: parser.object, list: parser.object[])>
@@ -194,6 +204,43 @@ end
return f
end})
+local eachChildMap = setmetatable({}, {__index = function (self, name)
+ local defs = childMap[name]
+ if not defs then
+ self[name] = false
+ return false
+ end
+ local text = {}
+ text[#text+1] = 'local obj, callback = ...'
+ for _, def in ipairs(defs) do
+ if def == '#' then
+ text[#text+1] = [[
+for i = 1, #obj do
+ callback(obj[i])
+end
+]]
+ elseif type(def) == 'string' and def:sub(1, 1) == '#' then
+ local key = def:sub(2)
+ text[#text+1] = ([[
+local childs = obj.%s
+if childs then
+ for i = 1, #childs do
+ callback(childs[i])
+ end
+end
+]]):format(key)
+ elseif type(def) == 'string' then
+ text[#text+1] = ('callback(obj.%s)'):format(def)
+ else
+ text[#text+1] = ('callback(obj[%q])'):format(def)
+ end
+ end
+ local buf = table.concat(text, '\n')
+ local f = load(buf, buf, 't')
+ self[name] = f
+ return f
+end})
+
m.actionMap = {
['main'] = {'#'},
['repeat'] = {'#'},
@@ -209,34 +256,8 @@ m.actionMap = {
['funcargs'] = {'#'},
}
-local inf = 1 / 0
-local nan = 0 / 0
-
-local function isInteger(n)
- if math.type then
- return math.type(n) == 'integer'
- else
- return type(n) == 'number' and n % 1 == 0
- end
-end
-
-local function formatNumber(n)
- if n == inf
- or n == -inf
- or n == nan
- or n ~= n then -- IEEE 标准中,NAN 不等于自己。但是某些实现中没有遵守这个规则
- return ('%q'):format(n)
- end
- if isInteger(n) then
- return tostring(n)
- end
- local str = ('%.10f'):format(n)
- str = str:gsub('%.?0*$', '')
- return str
-end
-
--- 是否是字面量
----@param obj parser.object
+---@param obj table
---@return boolean
function m.isLiteral(obj)
local tp = obj.type
@@ -252,6 +273,8 @@ function m.isLiteral(obj)
or tp == 'doc.type.string'
or tp == 'doc.type.integer'
or tp == 'doc.type.boolean'
+ or tp == 'doc.type.code'
+ or tp == 'doc.type.array'
end
--- 获取字面量
@@ -273,7 +296,7 @@ end
--- 寻找父函数
---@param obj parser.object
----@return parser.object
+---@return parser.object?
function m.getParentFunction(obj)
for _ = 1, 10000 do
obj = obj.parent
@@ -290,7 +313,7 @@ end
--- 寻找所在区块
---@param obj parser.object
----@return parser.object
+---@return parser.object?
function m.getBlock(obj)
for _ = 1, 10000 do
if not obj then
@@ -319,7 +342,7 @@ end
--- 寻找所在父区块
---@param obj parser.object
----@return parser.object
+---@return parser.object?
function m.getParentBlock(obj)
for _ = 1, 10000 do
obj = obj.parent
@@ -336,7 +359,7 @@ end
--- 寻找所在可break的父区块
---@param obj parser.object
----@return parser.object
+---@return parser.object?
function m.getBreakBlock(obj)
for _ = 1, 10000 do
obj = obj.parent
@@ -373,7 +396,7 @@ end
--- 寻找所在父类型
---@param obj parser.object
----@return parser.object
+---@return parser.object?
function m.getParentType(obj, want)
for _ = 1, 10000 do
obj = obj.parent
@@ -406,8 +429,7 @@ function m.getRoot(obj)
end
local parent = obj.parent
if not parent then
- log.error('Can not find out root:', obj.type)
- return nil
+ error('Can not find out root:' .. tostring(obj.type))
end
obj = parent
end
@@ -436,56 +458,54 @@ function m.getENV(source, start)
or m.getLocal(source, '@fenv', start)
end
---- 寻找函数的不定参数,返回不定参在第几个参数上,以及该参数对象。
---- 如果函数是主函数,则返回`0, nil`。
----@return table
----@return integer
-function m.getFunctionVarArgs(func)
- if func.type == 'main' then
- return 0, nil
- end
- if func.type ~= 'function' then
- return nil, nil
- end
- local args = func.args
- if not args then
- return nil, nil
- end
- for i = 1, #args do
- local arg = args[i]
- if arg.type == '...' then
- return i, arg
- end
- end
- return nil, nil
-end
-
--- 获取指定区块中可见的局部变量
---@param source parser.object
---@param name string # 变量名
---@param pos integer # 可见位置
---@return parser.object?
function m.getLocal(source, name, pos)
- local root = m.getRoot(source)
- local res
- m.eachSourceContain(root, pos, function (src)
- local locals = src.locals
- if not locals then
- return
+ local block = source
+ -- find nearest source
+ for _ = 1, 10000 do
+ if not block then
+ return nil
end
- for i = 1, #locals do
- local loc = locals[i]
- if loc.effect > pos then
- break
- end
- if loc[1] == name then
- if not res or res.effect < loc.effect then
- res = loc
+ if block.start <= pos
+ and block.finish >= pos
+ and blockTypes[block.type] then
+ break
+ end
+ block = block.parent
+ end
+
+ m.eachSourceContain(block, pos, function (src)
+ if blockTypes[src.type]
+ and (src.finish - src.start) < (block.finish - src.start) then
+ block = src
+ end
+ end)
+
+ for _ = 1, 10000 do
+ if not block then
+ break
+ end
+ local res
+ if block.locals then
+ for _, loc in ipairs(block.locals) do
+ if loc[1] == name
+ and loc.effect <= pos then
+ if not res or res.effect < loc.effect then
+ res = loc
+ end
end
end
end
- end)
- return res
+ if res then
+ return res
+ end
+ block = block.parent
+ end
+ return nil
end
--- 获取指定区块中所有的可见局部变量名称
@@ -507,25 +527,25 @@ function m.getVisibleLocals(block, pos)
end
--- 获取指定区块中可见的标签
----@param block table
----@param name string {comment = '标签名'}
+---@param block parser.object
+---@param name string
function m.getLabel(block, name)
- block = m.getBlock(block)
+ local current = m.getBlock(block)
for _ = 1, 10000 do
- if not block then
+ if not current then
return nil
end
- local labels = block.labels
+ local labels = current.labels
if labels then
local label = labels[name]
if label then
return label
end
end
- if block.type == 'function' then
+ if current.type == 'function' then
return nil
end
- block = m.getParentBlock(block)
+ current = m.getParentBlock(current)
end
error('guide.getLocal overstack')
end
@@ -749,6 +769,16 @@ function m.eachSource(ast, callback)
end
end
+---@param source parser.object
+---@param callback fun(src: parser.object)
+function m.eachChild(source, callback)
+ local f = eachChildMap[source.type]
+ if not f then
+ return
+ end
+ f(source, callback)
+end
+
--- 获取指定的 special
function m.eachSpecialOf(ast, name, callback)
local root = m.getRoot(ast)
@@ -797,6 +827,8 @@ function m.positionToOffset(state, position)
return m.positionToOffsetByLines(state.lines, position)
end
+---@param lines integer[]
+---@param offset integer
function m.offsetToPositionByLines(lines, offset)
local left = 0
local right = #lines
@@ -854,11 +886,14 @@ local isSetMap = {
['label'] = true,
['doc.class'] = true,
['doc.alias'] = true,
+ ['doc.enum'] = true,
['doc.field'] = true,
['doc.class.name'] = true,
['doc.alias.name'] = true,
+ ['doc.enum.name'] = true,
['doc.field.name'] = true,
['doc.type.field'] = true,
+ ['doc.type.array'] = true,
}
function m.isSet(source)
local tp = source.type
@@ -972,6 +1007,8 @@ function m.getKeyName(obj)
return obj.class[1]
elseif tp == 'doc.alias' then
return obj.alias[1]
+ elseif tp == 'doc.enum' then
+ return obj.enum[1]
elseif tp == 'doc.field' then
return obj.field[1]
elseif tp == 'doc.field.name' then
@@ -1035,6 +1072,8 @@ function m.getKeyType(obj)
return 'string'
elseif tp == 'doc.alias' then
return 'string'
+ elseif tp == 'doc.enum' then
+ return 'string'
elseif tp == 'doc.field' then
return type(obj.field[1])
elseif tp == 'doc.type.field' then
@@ -1058,9 +1097,9 @@ end
--- 返回的2个 `list` 分别为基准block到达 a 与 b 的路径。
---@param a table
---@param b table
----@return string|boolean mode
----@return table pathA?
----@return table pathB?
+---@return string|false mode
+---@return table? pathA
+---@return table? pathB
function m.getPath(a, b, sameFunction)
--- 首先测试双方在同一个函数内
if sameFunction and m.getParentFunction(a) ~= m.getParentFunction(b) then
@@ -1084,16 +1123,20 @@ function m.getPath(a, b, sameFunction)
local pathB = {}
for _ = 1, 1000 do
objA = m.getParentBlock(objA)
- pathA[#pathA+1] = objA
- if (not sameFunction and objA.type == 'function') or objA.type == 'main' then
- break
+ if objA then
+ pathA[#pathA+1] = objA
+ if (not sameFunction and objA.type == 'function') or objA.type == 'main' then
+ break
+ end
end
end
for _ = 1, 1000 do
objB = m.getParentBlock(objB)
- pathB[#pathB+1] = objB
- if (not sameFunction and objA.type == 'function') or objB.type == 'main' then
- break
+ if objB then
+ pathB[#pathB+1] = objB
+ if (not sameFunction and objB.type == 'function') or objB.type == 'main' then
+ break
+ end
end
end
-- pathA: {1, 2, 3, 4, 5}
@@ -1108,7 +1151,7 @@ function m.getPath(a, b, sameFunction)
end
end
if not start then
- return nil
+ return false
end
-- pathA: { 1, 2, 3}
-- pathB: {5, 6, 2, 3}
@@ -1201,6 +1244,15 @@ function m.isInString(ast, position)
end)
end
+function m.isInComment(ast, offset)
+ for _, com in ipairs(ast.state.comms) do
+ if offset >= com.start and offset <= com.finish then
+ return true
+ end
+ end
+ return false
+end
+
function m.isOOP(source)
if source.type == 'setmethod'
or source.type == 'getmethod' then
@@ -1221,6 +1273,7 @@ local basicTypeMap = {
['false'] = true,
['nil'] = true,
['boolean'] = true,
+ ['integer'] = true,
['number'] = true,
['string'] = true,
['table'] = true,
@@ -1235,4 +1288,10 @@ function m.isBasicType(str)
return basicTypeMap[str] == true
end
+---@param source parser.object
+---@return boolean
+function m.isBlockType(source)
+ return blockTypes[source.type] == true
+end
+
return m
diff --git a/script/parser/init.lua b/script/parser/init.lua
index 219f8900..bc004f77 100644
--- a/script/parser/init.lua
+++ b/script/parser/init.lua
@@ -1,13 +1,8 @@
local api = {
- grammar = require 'parser.grammar',
- parse = require 'parser.parse',
compile = require 'parser.compile',
- split = require 'parser.split',
- calcline = require 'parser.calcline',
lines = require 'parser.lines',
guide = require 'parser.guide',
luadoc = require 'parser.luadoc',
- tokens = require 'parser.tokens',
}
return api
diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua
index d8e31950..51161565 100644
--- a/script/parser/luadoc.lua
+++ b/script/parser/luadoc.lua
@@ -1,22 +1,28 @@
local m = require 'lpeglabel'
local re = require 'parser.relabel'
local guide = require 'parser.guide'
-local parser = require 'parser.newparser'
+local compile = require 'parser.compile'
local util = require 'utility'
local TokenTypes, TokenStarts, TokenFinishs, TokenContents, TokenMarks
-local Ci, Offset, pushWarning, NextComment, Lines
+---@type integer
+local Ci
+---@type integer
+local Offset
+local pushWarning, NextComment, Lines
local parseType, parseTypeUnit
---@type any
local Parser = re.compile([[
Main <- (Token / Sp)*
Sp <- %s+
X16 <- [a-fA-F0-9]
-Token <- Integer / Name / String / Symbol
+Token <- Integer / Name / String / Code / Symbol
Name <- ({} {%name} {})
-> Name
-Integer <- ({} {[0-9]+} !'.' {})
+Integer <- ({} {'-'? [0-9]+} !'.' {})
-> Integer
+Code <- ({} '`' { (!'`' .)*} '`' {})
+ -> Code
String <- ({} StringDef {})
-> String
StringDef <- {'"'}
@@ -48,7 +54,7 @@ EChar <- 'a' -> ea
/ ([0-9] [0-9]? [0-9]?) -> Char10
/ ('u{' {X16*} '}') -> CharUtf8
Symbol <- ({} {
- [:|,<>()?+#`{}]
+ [:|,;<>()?+#{}]
/ '[]'
/ '...'
/ '['
@@ -67,6 +73,7 @@ Symbol <- ({} {
ev = '\v',
name = (m.R('az', 'AZ', '09', '\x80\xff') + m.S('_')) * (m.R('az', 'AZ', '__', '09', '\x80\xff') + m.S('_.*-'))^0,
Char10 = function (char)
+ ---@type integer?
char = tonumber(char)
if not char or char < 0 or char > 255 then
return ''
@@ -114,6 +121,13 @@ Symbol <- ({} {
TokenFinishs[Ci] = finish - 1
TokenContents[Ci] = math.tointeger(content)
end,
+ Code = function (start, content, finish)
+ Ci = Ci + 1
+ TokenTypes[Ci] = 'code'
+ TokenStarts[Ci] = start
+ TokenFinishs[Ci] = finish - 1
+ TokenContents[Ci] = content
+ end,
Symbol = function (start, content, finish)
Ci = Ci + 1
TokenTypes[Ci] = 'symbol'
@@ -128,10 +142,12 @@ Symbol <- ({} {
---@field signs parser.object[]
---@field originalComment parser.object
---@field as? parser.object
-
-local function trim(str)
- return str:match '^%s*(%S+)%s*$'
-end
+---@field touch? integer
+---@field module? string
+---@field async? boolean
+---@field versions? table[]
+---@field names? parser.object[]
+---@field path? string
local function parseTokens(text, offset)
Ci = 0
@@ -149,11 +165,13 @@ local function peekToken()
return TokenTypes[Ci+1], TokenContents[Ci+1]
end
+---@return string? tokenType
+---@return string? tokenContent
local function nextToken()
Ci = Ci + 1
if not TokenTypes[Ci] then
Ci = Ci - 1
- return nil
+ return nil, nil
end
return TokenTypes[Ci], TokenContents[Ci]
end
@@ -171,6 +189,7 @@ local function getStart()
return TokenStarts[Ci] + Offset
end
+---@return integer
local function getFinish()
if Ci == 0 then
return Offset
@@ -273,6 +292,11 @@ local function parseTable(parent)
}
do
+ local needCloseParen
+ if checkToken('symbol', '(', 1) then
+ nextToken()
+ needCloseParen = true
+ end
field.name = parseName('doc.field.name', field)
or parseIndexField('doc.field.name', field)
if not field.name then
@@ -299,10 +323,14 @@ local function parseTable(parent)
break
end
field.finish = getFinish()
+ if needCloseParen then
+ nextSymbolOrError ')'
+ end
end
typeUnit.fields[#typeUnit.fields+1] = field
- if checkToken('symbol', ',', 1) then
+ if checkToken('symbol', ',', 1)
+ or checkToken('symbol', ';', 1) then
nextToken()
else
nextSymbolOrError('}')
@@ -412,11 +440,35 @@ local function parseTypeUnitFunction(parent)
end
if checkToken('symbol', ':', 1) then
nextToken()
+ local needCloseParen
+ if checkToken('symbol', '(', 1) then
+ nextToken()
+ needCloseParen = true
+ end
while true do
+ local name
+ try(function ()
+ local returnName = parseName('doc.return.name', typeUnit)
+ or parseDots('doc.return.name', typeUnit)
+ if not returnName then
+ return false
+ end
+ if checkToken('symbol', ':', 1) then
+ nextToken()
+ name = returnName
+ return true
+ end
+ if returnName[1] == '...' then
+ name = returnName
+ return false
+ end
+ return false
+ end)
local rtn = parseType(typeUnit)
if not rtn then
break
end
+ rtn.name = name
if checkToken('symbol', '?', 1) then
nextToken()
rtn.optional = true
@@ -428,6 +480,9 @@ local function parseTypeUnitFunction(parent)
break
end
end
+ if needCloseParen then
+ nextSymbolOrError ')'
+ end
end
typeUnit.finish = getFinish()
return typeUnit
@@ -534,6 +589,22 @@ local function parseString(parent)
return str
end
+local function parseCode(parent)
+ local tp, content = peekToken()
+ if not tp or tp ~= 'code' then
+ return nil
+ end
+ nextToken()
+ local code = {
+ type = 'doc.type.code',
+ start = getStart(),
+ finish = getFinish(),
+ parent = parent,
+ [1] = content,
+ }
+ return code
+end
+
local function parseInteger(parent)
local tp, content = peekToken()
if not tp or tp ~= 'integer' then
@@ -584,22 +655,18 @@ function parseTypeUnit(parent)
local result = parseFunction(parent)
or parseTable(parent)
or parseString(parent)
+ or parseCode(parent)
or parseInteger(parent)
or parseBoolean(parent)
- or parseDots('doc.type.name', parent)
or parseParen(parent)
if not result then
- local literal = checkToken('symbol', '`', 1)
- if literal then
- nextToken()
- end
result = parseName('doc.type.name', parent)
+ or parseDots('doc.type.name', parent)
if not result then
return nil
end
- if literal then
- result.literal = true
- nextSymbolOrError '`'
+ if result[1] == '...' then
+ result[1] = 'unknown'
end
end
while true do
@@ -749,8 +816,9 @@ local docSwitch = util.switch()
: case 'class'
: call(function ()
local result = {
- type = 'doc.class',
- fields = {},
+ type = 'doc.class',
+ fields = {},
+ operators = {},
}
result.class = parseName('doc.class.name', result)
if not result.class then
@@ -793,7 +861,20 @@ local docSwitch = util.switch()
end)
: case 'type'
: call(function ()
- return parseType()
+ local first = parseType()
+ if not first then
+ return nil
+ end
+ local rests
+ while checkToken('symbol', ',', 1) do
+ nextToken()
+ local rest = parseType()
+ if not rests then
+ rests = {}
+ end
+ rests[#rests+1] = rest
+ end
+ return first, rests
end)
: case 'alias'
: call(function ()
@@ -864,6 +945,10 @@ local docSwitch = util.switch()
returns = {},
}
while true do
+ local dots = parseDots('doc.return.name')
+ if dots then
+ Ci = Ci - 1
+ end
local docType = parseType(result)
if not docType then
break
@@ -875,7 +960,13 @@ local docSwitch = util.switch()
nextToken()
docType.optional = true
end
- docType.name = parseName('doc.return.name', docType)
+ if dots then
+ docType.name = dots
+ dots.parent = docType
+ else
+ docType.name = parseName('doc.return.name', docType)
+ or parseDots('doc.return.name', docType)
+ end
result.returns[#result.returns+1] = docType
if not checkToken('symbol', ',', 1) then
break
@@ -1250,8 +1341,7 @@ local docSwitch = util.switch()
if checkToken('symbol', '?', 1) then
block.optional = true
nextToken()
- block.start = block.start or getStart()
- block.finish = block.finish
+ block.finish = getFinish()
else
block.extends = parseType(block)
if block.extends then
@@ -1263,6 +1353,7 @@ local docSwitch = util.switch()
if block.optional or block.extends then
result.casts[#result.casts+1] = block
end
+ result.finish = block.finish
if checkToken('symbol', ',', 1) then
nextToken()
@@ -1273,8 +1364,87 @@ local docSwitch = util.switch()
return result
end)
+ : case 'operator'
+ : call(function ()
+ local result = {
+ type = 'doc.operator',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+
+ local op = parseName('doc.operator.name', result)
+ if not op then
+ pushWarning {
+ type = 'LUADOC_MISS_OPERATOR_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ result.op = op
+ result.finish = op.finish
+
+ if checkToken('symbol', '(', 1) then
+ nextToken()
+ local exp = parseType(result)
+ if exp then
+ result.exp = exp
+ result.finish = exp.finish
+ end
+ nextSymbolOrError ')'
+ end
+
+ nextSymbolOrError ':'
+
+ local ret = parseType(result)
+ if ret then
+ result.extends = ret
+ result.finish = ret.finish
+ end
-local function convertTokens()
+ return result
+ end)
+ : case 'source'
+ : call(function (doc)
+ local fullSource = doc:sub(#'source' + 1)
+ if not fullSource or fullSource == '' then
+ return
+ end
+ fullSource = util.trim(fullSource)
+ if fullSource == '' then
+ return
+ end
+ local source, line, char = fullSource:match('^(.-):?(%d*):?(%d*)$')
+ source = source or fullSource
+ line = tonumber(line) or 1
+ char = tonumber(char) or 0
+ local result = {
+ type = 'doc.source',
+ start = getStart(),
+ finish = getFinish(),
+ path = source,
+ line = line,
+ char = char,
+ }
+ return result
+ end)
+ : case 'enum'
+ : call(function ()
+ local name = parseName('doc.enum.name')
+ if not name then
+ return nil
+ end
+ local result = {
+ type = 'doc.enum',
+ start = name.start,
+ finish = name.finish,
+ enum = name,
+ }
+ name.parent = result
+ return result
+ end)
+
+local function convertTokens(doc)
local tp, text = nextToken()
if not tp then
return
@@ -1287,7 +1457,7 @@ local function convertTokens()
}
return nil
end
- return docSwitch(text)
+ return docSwitch(text, doc)
end
local function trimTailComment(text)
@@ -1302,7 +1472,7 @@ local function trimTailComment(text)
comment = text:sub(3)
end
if comment:find '^%s*[\'"[]' then
- local state = parser(comment:gsub('^%s+', ''), 'String')
+ local state = compile(comment:gsub('^%s+', ''), 'String')
if state and state.ast then
comment = state.ast[1]
end
@@ -1327,10 +1497,17 @@ local function buildLuaDoc(comment)
local doc = text:sub(startPos)
parseTokens(doc, comment.start + startPos)
- local result = convertTokens()
+ local result, rests = convertTokens(doc)
if result then
result.range = comment.finish
- local cstart = text:find('%S', (result.firstFinish or result.finish) - comment.start)
+ local finish = result.firstFinish or result.finish
+ if rests then
+ for _, rest in ipairs(rests) do
+ rest.range = comment.finish
+ finish = rest.firstFinish or result.finish
+ end
+ end
+ local cstart = text:find('%S', finish - comment.start)
if cstart and cstart < comment.finish then
result.comment = {
type = 'doc.tailcomment',
@@ -1339,11 +1516,16 @@ local function buildLuaDoc(comment)
parent = result,
text = trimTailComment(text:sub(cstart)),
}
+ if rests then
+ for _, rest in ipairs(rests) do
+ rest.comment = result.comment
+ end
+ end
end
end
if result then
- return result
+ return result, rests
end
return {
@@ -1355,37 +1537,54 @@ local function buildLuaDoc(comment)
}
end
-local function isTailComment(text, binded)
- local lastDoc = binded[#binded]
- local left = lastDoc.originalComment.start
+local function isTailComment(text, doc)
+ if not doc then
+ return false
+ end
+ local left = doc.originalComment.start
local row, col = guide.rowColOf(left)
local lineStart = Lines[row] or 0
local hasCodeBefore = text:sub(lineStart, lineStart + col):find '[%w_]'
return hasCodeBefore
end
-local function isNextLine(binded, doc)
- if not binded then
+local function isContinuedDoc(lastDoc, nextDoc)
+ if not nextDoc then
return false
end
- local lastDoc = binded[#binded]
+ if nextDoc.type == 'doc.diagnostic' then
+ return true
+ end
if lastDoc.type == 'doc.type'
- or lastDoc.type == 'doc.module' then
- return false
+ or lastDoc.type == 'doc.module'
+ or lastDoc.type == 'doc.enum' then
+ if nextDoc.type ~= 'doc.comment' then
+ return false
+ end
end
if lastDoc.type == 'doc.class'
- or lastDoc.type == 'doc.field' then
- if doc.type ~= 'doc.field'
- and doc.type ~= 'doc.comment'
- and doc.type ~= 'doc.overload' then
+ or lastDoc.type == 'doc.field'
+ or lastDoc.type == 'doc.operator' then
+ if nextDoc.type ~= 'doc.field'
+ and nextDoc.type ~= 'doc.operator'
+ and nextDoc.type ~= 'doc.comment'
+ and nextDoc.type ~= 'doc.overload'
+ and nextDoc.type ~= 'doc.source' then
return false
end
end
- if doc.type == 'doc.cast' then
+ if nextDoc.type == 'doc.cast' then
+ return false
+ end
+ return true
+end
+
+local function isNextLine(lastDoc, nextDoc)
+ if not nextDoc then
return false
end
local lastRow = guide.rowColOf(lastDoc.finish)
- local newRow = guide.rowColOf(doc.start)
+ local newRow = guide.rowColOf(nextDoc.start)
return newRow - lastRow == 1
end
@@ -1408,6 +1607,7 @@ local function bindGeneric(binded)
end
end
if doc.type == 'doc.param'
+ or doc.type == 'doc.vararg'
or doc.type == 'doc.return'
or doc.type == 'doc.type'
or doc.type == 'doc.class'
@@ -1418,11 +1618,115 @@ local function bindGeneric(binded)
src.type = 'doc.generic.name'
end
end)
+ guide.eachSourceType(doc, 'doc.type.code', function (src)
+ local name = src[1]
+ if generics[name] then
+ src.type = 'doc.generic.name'
+ src.literal = true
+ end
+ end)
+ end
+ end
+end
+
+local function bindDoc(source, binded)
+ local isParam = source.type == 'self'
+ or source.type == 'local'
+ and (source.parent.type == 'funcargs'
+ or ( source.parent.type == 'in'
+ and source.finish <= source.parent.keys.finish
+ )
+ )
+ local ok = false
+ for _, doc in ipairs(binded) do
+ if doc.bindSource then
+ goto CONTINUE
+ end
+ if doc.type == 'doc.class'
+ or doc.type == 'doc.deprecated'
+ or doc.type == 'doc.version'
+ or doc.type == 'doc.module'
+ or doc.type == 'doc.source' then
+ if source.type == 'function'
+ or isParam then
+ goto CONTINUE
+ end
+ elseif doc.type == 'doc.type' then
+ if source.type == 'function'
+ or isParam
+ or source._bindedDocType then
+ goto CONTINUE
+ end
+ source._bindedDocType = true
+ elseif doc.type == 'doc.overload' then
+ if not source.bindDocs then
+ source.bindDocs = {}
+ end
+ source.bindDocs[#source.bindDocs+1] = doc
+ if source.type ~= 'function' then
+ doc.bindSource = source
+ end
+ elseif doc.type == 'doc.param' then
+ local suc
+ if isParam
+ and doc.param[1] == source[1] then
+ suc = true
+ elseif source.type == '...'
+ and doc.param[1] == '...' then
+ suc = true
+ elseif source.type == 'self'
+ and doc.param[1] == 'self' then
+ suc = true
+ end
+ if source.type == 'function' then
+ if not source.bindDocs then
+ source.bindDocs = {}
+ end
+ source.bindDocs[#source.bindDocs+1] = doc
+ end
+
+ if not suc then
+ goto CONTINUE
+ end
+ elseif doc.type == 'doc.vararg' then
+ if source.type ~= '...' then
+ goto CONTINUE
+ end
+ elseif doc.type == 'doc.return'
+ or doc.type == 'doc.generic'
+ or doc.type == 'doc.async'
+ or doc.type == 'doc.nodiscard' then
+ if source.type ~= 'function' then
+ goto CONTINUE
+ end
+ elseif doc.type == 'doc.enum' then
+ if source.type == 'table' then
+ goto OK
+ end
+ if source.value and source.value.type == 'table' then
+ if not source.value.bindDocs then
+ source.value.bindDocs = {}
+ end
+ source.value.bindDocs[#source.value.bindDocs+1] = doc
+ doc.bindSource = source.value
+ end
+ goto CONTINUE
+ elseif doc.type ~= 'doc.comment' then
+ goto CONTINUE
end
+ ::OK::
+ if not source.bindDocs then
+ source.bindDocs = {}
+ end
+ source.bindDocs[#source.bindDocs+1] = doc
+ doc.bindSource = source
+ ok = true
+ ::CONTINUE::
end
+ return ok
end
-local function bindDocsBetween(sources, binded, bindSources, start, finish)
+local function bindDocsBetween(sources, binded, start, finish)
-- 用二分法找到第一个
local max = #sources
local index
@@ -1445,6 +1749,7 @@ local function bindDocsBetween(sources, binded, bindSources, start, finish)
end
end
+ local ok = false
-- 从前往后进行绑定
for i = index, max do
local src = sources[i]
@@ -1452,12 +1757,6 @@ local function bindDocsBetween(sources, binded, bindSources, start, finish)
if src.start >= finish then
break
end
- -- 遇到table后中断,处理以下情况:
- -- ---@type AAA
- -- local t = {x = 1, y = 2}
- if src.type == 'table' then
- break
- end
if src.start >= start then
if src.type == 'local'
or src.type == 'self'
@@ -1468,13 +1767,18 @@ local function bindDocsBetween(sources, binded, bindSources, start, finish)
or src.type == 'setfield'
or src.type == 'setindex'
or src.type == 'setmethod'
- or src.type == 'function' then
- src.bindDocs = binded
- bindSources[#bindSources+1] = src
+ or src.type == 'function'
+ or src.type == 'table'
+ or src.type == '...' then
+ if bindDoc(src, binded) then
+ ok = true
+ end
end
end
end
end
+
+ return ok
end
local function bindReturnIndex(binded)
@@ -1489,25 +1793,64 @@ local function bindReturnIndex(binded)
end
end
-local function bindClassAndFields(binded)
+local function bindCommentsToDoc(doc, comments)
+ doc.bindComments = comments
+ for _, comment in ipairs(comments) do
+ comment.bindSource = doc
+ end
+end
+
+local function bindCommentsAndFields(binded)
local class
+ local comments = {}
+ local source
for _, doc in ipairs(binded) do
if doc.type == 'doc.class' then
-- 多个class连续写在一起,只有最后一个class可以绑定source
if class then
- class.bindSources = nil
+ class.bindSource = nil
+ end
+ if source then
+ doc.source = source
+ source.bindSource = doc
end
class = doc
+ bindCommentsToDoc(doc, comments)
+ comments = {}
elseif doc.type == 'doc.field' then
if class then
class.fields[#class.fields+1] = doc
doc.class = class
end
- end
+ if source then
+ doc.source = source
+ source.bindSource = doc
+ end
+ bindCommentsToDoc(doc, comments)
+ comments = {}
+ elseif doc.type == 'doc.operator' then
+ if class then
+ class.operators[#class.operators+1] = doc
+ doc.class = class
+ end
+ bindCommentsToDoc(doc, comments)
+ comments = {}
+ elseif doc.type == 'doc.alias'
+ or doc.type == 'doc.enum' then
+ bindCommentsToDoc(doc, comments)
+ comments = {}
+ elseif doc.type == 'doc.comment' then
+ comments[#comments+1] = doc
+ elseif doc.type == 'doc.source' then
+ source = doc
+ goto CONTINUE
+ end
+ source = nil
+ ::CONTINUE::
end
end
-local function bindDoc(sources, binded)
+local function bindDocWithSources(sources, binded)
if not binded then
return
end
@@ -1515,19 +1858,17 @@ local function bindDoc(sources, binded)
if not lastDoc then
return
end
- local bindSources = {}
for _, doc in ipairs(binded) do
doc.bindGroup = binded
- doc.bindSources = bindSources
end
bindGeneric(binded)
+ bindCommentsAndFields(binded)
+ bindReturnIndex(binded)
local row = guide.rowColOf(lastDoc.finish)
- bindDocsBetween(sources, binded, bindSources, guide.positionOf(row, 0), lastDoc.start)
- if #bindSources == 0 then
- bindDocsBetween(sources, binded, bindSources, guide.positionOf(row + 1, 0), guide.positionOf(row + 2, 0))
+ local suc = bindDocsBetween(sources, binded, guide.positionOf(row, 0), lastDoc.start)
+ if not suc then
+ bindDocsBetween(sources, binded, guide.positionOf(row + 1, 0), guide.positionOf(row + 2, 0))
end
- bindReturnIndex(binded)
- bindClassAndFields(binded)
end
local bindDocAccept = {
@@ -1547,19 +1888,44 @@ local function bindDocs(state)
return a.start < b.start
end)
local binded
- for _, doc in ipairs(state.ast.docs) do
- if not isNextLine(binded, doc) then
- bindDoc(sources, binded)
+ for i, doc in ipairs(state.ast.docs) do
+ if not binded then
binded = {}
state.ast.docs.groups[#state.ast.docs.groups+1] = binded
end
binded[#binded+1] = doc
- if isTailComment(text, binded) then
- bindDoc(sources, binded)
+ if isTailComment(text, doc) then
+ bindDocWithSources(sources, binded)
binded = nil
+ else
+ local nextDoc = state.ast.docs[i+1]
+ if not isNextLine(doc, nextDoc) then
+ bindDocWithSources(sources, binded)
+ binded = nil
+ end
+ if not isContinuedDoc(doc, nextDoc)
+ and not isTailComment(text, nextDoc) then
+ bindDocWithSources(sources, binded)
+ binded = nil
+ end
+ end
+ end
+end
+
+local function findTouch(state, doc)
+ local text = state.lua
+ local pos = guide.positionToOffset(state, doc.originalComment.start)
+ for i = pos - 2, 1, -1 do
+ local c = text:sub(i, i)
+ if c == '\r'
+ or c == '\n' then
+ break
+ elseif c ~= ' '
+ and c ~= '\t' then
+ doc.touch = guide.offsetToPosition(state, i)
+ break
end
end
- bindDoc(sources, binded)
end
return function (state)
@@ -1589,25 +1955,40 @@ return function (state)
return comment
end
+ local function insertDoc(doc, comment)
+ ast.docs[#ast.docs+1] = doc
+ doc.parent = ast.docs
+ if ast.start > doc.start then
+ ast.start = doc.start
+ end
+ if ast.finish < doc.finish then
+ ast.finish = doc.finish
+ end
+ doc.originalComment = comment
+ if comment.type == 'comment.long' then
+ findTouch(state, doc)
+ end
+ end
+
while true do
local comment = NextComment()
if not comment then
break
end
- local doc = buildLuaDoc(comment)
+ local doc, rests = buildLuaDoc(comment)
if doc then
- ast.docs[#ast.docs+1] = doc
- doc.parent = ast.docs
- if ast.start > doc.start then
- ast.start = doc.start
- end
- if ast.finish < doc.finish then
- ast.finish = doc.finish
+ insertDoc(doc, comment)
+ if rests then
+ for _, rest in ipairs(rests) do
+ insertDoc(rest, comment)
+ end
end
- doc.originalComment = comment
end
end
+ ast.docs.start = ast.start
+ ast.docs.finish = ast.finish
+
if #ast.docs == 0 then
return
end
diff --git a/script/parser/newparser.lua b/script/parser/newparser.lua
deleted file mode 100644
index 630c12c2..00000000
--- a/script/parser/newparser.lua
+++ /dev/null
@@ -1,3855 +0,0 @@
-local tokens = require 'parser.tokens'
-local guide = require 'parser.guide'
-
-local sbyte = string.byte
-local sfind = string.find
-local smatch = string.match
-local sgsub = string.gsub
-local ssub = string.sub
-local schar = string.char
-local supper = string.upper
-local uchar = utf8.char
-local tconcat = table.concat
-local tinsert = table.insert
-local tointeger = math.tointeger
-local mtype = math.type
-local tonumber = tonumber
-local maxinteger = math.maxinteger
-local assert = assert
-local next = next
-
-_ENV = nil
-
----@alias parser.position integer
-
----@param str string
----@return table<integer, boolean>
-local function stringToCharMap(str)
- local map = {}
- local pos = 1
- while pos <= #str do
- local byte = sbyte(str, pos, pos)
- map[schar(byte)] = true
- pos = pos + 1
- if ssub(str, pos, pos) == '-'
- and pos < #str then
- pos = pos + 1
- local byte2 = sbyte(str, pos, pos)
- assert(byte < byte2)
- for b = byte + 1, byte2 do
- map[schar(b)] = true
- end
- pos = pos + 1
- end
- end
- return map
-end
-
-local CharMapNumber = stringToCharMap '0-9'
-local CharMapN16 = stringToCharMap 'xX'
-local CharMapN2 = stringToCharMap 'bB'
-local CharMapE10 = stringToCharMap 'eE'
-local CharMapE16 = stringToCharMap 'pP'
-local CharMapSign = stringToCharMap '+-'
-local CharMapSB = stringToCharMap 'ao|~&=<>.*/%^+-'
-local CharMapSU = stringToCharMap 'n#~!-'
-local CharMapSimple = stringToCharMap '.:([\'"{'
-local CharMapStrSH = stringToCharMap '\'"`'
-local CharMapStrLH = stringToCharMap '['
-local CharMapTSep = stringToCharMap ',;'
-local CharMapWord = stringToCharMap '_a-zA-Z\x80-\xff'
-
-local EscMap = {
- ['a'] = '\a',
- ['b'] = '\b',
- ['f'] = '\f',
- ['n'] = '\n',
- ['r'] = '\r',
- ['t'] = '\t',
- ['v'] = '\v',
- ['\\'] = '\\',
- ['\''] = '\'',
- ['\"'] = '\"',
-}
-
-local NLMap = {
- ['\n'] = true,
- ['\r'] = true,
- ['\r\n'] = true,
-}
-
-local LineMulti = 10000
-
--- goto 单独处理
-local KeyWord = {
- ['and'] = true,
- ['break'] = true,
- ['do'] = true,
- ['else'] = true,
- ['elseif'] = true,
- ['end'] = true,
- ['false'] = true,
- ['for'] = true,
- ['function'] = true,
- ['if'] = true,
- ['in'] = true,
- ['local'] = true,
- ['nil'] = true,
- ['not'] = true,
- ['or'] = true,
- ['repeat'] = true,
- ['return'] = true,
- ['then'] = true,
- ['true'] = true,
- ['until'] = true,
- ['while'] = true,
-}
-
-local Specials = {
- ['_G'] = true,
- ['rawset'] = true,
- ['rawget'] = true,
- ['setmetatable'] = true,
- ['require'] = true,
- ['dofile'] = true,
- ['loadfile'] = true,
- ['pcall'] = true,
- ['xpcall'] = true,
- ['pairs'] = true,
- ['ipairs'] = true,
- ['assert'] = true,
-}
-
-local UnarySymbol = {
- ['not'] = 11,
- ['#'] = 11,
- ['~'] = 11,
- ['-'] = 11,
-}
-
-local BinarySymbol = {
- ['or'] = 1,
- ['and'] = 2,
- ['<='] = 3,
- ['>='] = 3,
- ['<'] = 3,
- ['>'] = 3,
- ['~='] = 3,
- ['=='] = 3,
- ['|'] = 4,
- ['~'] = 5,
- ['&'] = 6,
- ['<<'] = 7,
- ['>>'] = 7,
- ['..'] = 8,
- ['+'] = 9,
- ['-'] = 9,
- ['*'] = 10,
- ['//'] = 10,
- ['/'] = 10,
- ['%'] = 10,
- ['^'] = 12,
-}
-
-local BinaryAlias = {
- ['&&'] = 'and',
- ['||'] = 'or',
- ['!='] = '~=',
-}
-
-local BinaryActionAlias = {
- ['='] = '==',
-}
-
-local UnaryAlias = {
- ['!'] = 'not',
-}
-
-local SymbolForward = {
- [01] = true,
- [02] = true,
- [03] = true,
- [04] = true,
- [05] = true,
- [06] = true,
- [07] = true,
- [08] = false,
- [09] = true,
- [10] = true,
- [11] = true,
- [12] = false,
-}
-
-local GetToSetMap = {
- ['getglobal'] = 'setglobal',
- ['getlocal'] = 'setlocal',
- ['getfield'] = 'setfield',
- ['getindex'] = 'setindex',
- ['getmethod'] = 'setmethod',
-}
-
-local ChunkFinishMap = {
- ['end'] = true,
- ['else'] = true,
- ['elseif'] = true,
- ['in'] = true,
- ['then'] = true,
- ['until'] = true,
- [';'] = true,
- [']'] = true,
- [')'] = true,
- ['}'] = true,
-}
-
-local ChunkStartMap = {
- ['do'] = true,
- ['else'] = true,
- ['elseif'] = true,
- ['for'] = true,
- ['function'] = true,
- ['if'] = true,
- ['local'] = true,
- ['repeat'] = true,
- ['return'] = true,
- ['then'] = true,
- ['until'] = true,
- ['while'] = true,
-}
-
-local ListFinishMap = {
- ['end'] = true,
- ['else'] = true,
- ['elseif'] = true,
- ['in'] = true,
- ['then'] = true,
- ['do'] = true,
- ['until'] = true,
- ['for'] = true,
- ['if'] = true,
- ['local'] = true,
- ['repeat'] = true,
- ['return'] = true,
- ['while'] = true,
-}
-
-local State, Lua, Line, LineOffset, Chunk, Tokens, Index, LastTokenFinish, Mode, LocalCount
-
-local LocalLimit = 200
-
-local parseExp, parseAction
-
-local pushError
-
-local function addSpecial(name, obj)
- if not State.specials then
- State.specials = {}
- end
- if not State.specials[name] then
- State.specials[name] = {}
- end
- State.specials[name][#State.specials[name]+1] = obj
- obj.special = name
-end
-
----@param offset integer
----@param leftOrRight '"left"'|'"right"'
-local function getPosition(offset, leftOrRight)
- if not offset or offset > #Lua then
- return LineMulti * Line + #Lua - LineOffset + 1
- end
- if leftOrRight == 'left' then
- return LineMulti * Line + offset - LineOffset
- else
- return LineMulti * Line + offset - LineOffset + 1
- end
-end
-
----@return string word
----@return parser.position startPosition
----@return parser.position finishPosition
----@return integer newOffset
-local function peekWord()
- local word = Tokens[Index + 1]
- if not word then
- return nil
- end
- if not CharMapWord[ssub(word, 1, 1)] then
- return nil
- end
- local startPos = getPosition(Tokens[Index] , 'left')
- local finishPos = getPosition(Tokens[Index] + #word - 1, 'right')
- return word, startPos, finishPos
-end
-
-local function lastRightPosition()
- if Index < 2 then
- return 0
- end
- local token = Tokens[Index - 1]
- if NLMap[token] then
- return LastTokenFinish
- elseif token then
- return getPosition(Tokens[Index - 2] + #token - 1, 'right')
- else
- return getPosition(#Lua, 'right')
- end
-end
-
-local function missSymbol(symbol, start, finish)
- pushError {
- type = 'MISS_SYMBOL',
- start = start or lastRightPosition(),
- finish = finish or start or lastRightPosition(),
- info = {
- symbol = symbol,
- }
- }
-end
-
-local function missExp()
- pushError {
- type = 'MISS_EXP',
- start = lastRightPosition(),
- finish = lastRightPosition(),
- }
-end
-
-local function missName(pos)
- pushError {
- type = 'MISS_NAME',
- start = pos or lastRightPosition(),
- finish = pos or lastRightPosition(),
- }
-end
-
-local function missEnd(relatedStart, relatedFinish)
- pushError {
- type = 'MISS_SYMBOL',
- start = lastRightPosition(),
- finish = lastRightPosition(),
- info = {
- symbol = 'end',
- related = {
- {
- start = relatedStart,
- finish = relatedFinish,
- }
- }
- }
- }
- pushError {
- type = 'MISS_END',
- start = relatedStart,
- finish = relatedFinish,
- }
-end
-
-local function unknownSymbol(start, finish, word)
- local token = word or Tokens[Index + 1]
- if not token then
- return false
- end
- pushError {
- type = 'UNKNOWN_SYMBOL',
- start = start or getPosition(Tokens[Index], 'left'),
- finish = finish or getPosition(Tokens[Index] + #token - 1, 'right'),
- info = {
- symbol = token,
- }
- }
- return true
-end
-
-local function skipUnknownSymbol(stopSymbol)
- if unknownSymbol() then
- Index = Index + 2
- return true
- end
- return false
-end
-
-local function skipNL()
- local token = Tokens[Index + 1]
- if NLMap[token] then
- if Index >= 2 and not NLMap[Tokens[Index - 1]] then
- LastTokenFinish = getPosition(Tokens[Index - 2] + #Tokens[Index - 1] - 1, 'right')
- end
- Line = Line + 1
- LineOffset = Tokens[Index] + #token
- Index = Index + 2
- State.lines[Line] = LineOffset
- return true
- end
- return false
-end
-
-local function getSavePoint()
- local index = Index
- local line = Line
- local lineOffset = LineOffset
- local errs = State.errs
- local errCount = #errs
- return function ()
- Index = index
- Line = line
- LineOffset = lineOffset
- for i = errCount + 1, #errs do
- errs[i] = nil
- end
- end
-end
-
-local function fastForwardToken(offset)
- while true do
- local myOffset = Tokens[Index]
- if not myOffset
- or myOffset >= offset then
- break
- end
- local token = Tokens[Index + 1]
- if NLMap[token] then
- Line = Line + 1
- LineOffset = Tokens[Index] + #token
- State.lines[Line] = LineOffset
- end
- Index = Index + 2
- end
-end
-
-local function resolveLongString(finishMark)
- skipNL()
- local miss
- local start = Tokens[Index]
- local finishOffset = sfind(Lua, finishMark, start, true)
- if not finishOffset then
- finishOffset = #Lua + 1
- miss = true
- end
- local stringResult = start and ssub(Lua, start, finishOffset - 1) or ''
- local lastLN = stringResult:find '[\r\n][^\r\n]*$'
- if lastLN then
- local result = stringResult
- : gsub('\r\n?', '\n')
- stringResult = result
- end
- fastForwardToken(finishOffset + #finishMark)
- if miss then
- local pos = getPosition(finishOffset - 1, 'right')
- pushError {
- type = 'MISS_SYMBOL',
- start = pos,
- finish = pos,
- info = {
- symbol = finishMark,
- },
- fix = {
- title = 'ADD_LSTRING_END',
- {
- start = pos,
- finish = pos,
- text = finishMark,
- }
- },
- }
- end
- return stringResult, getPosition(finishOffset + #finishMark - 1, 'right')
-end
-
-local function parseLongString()
- local start, finish, mark = sfind(Lua, '^(%[%=*%[)', Tokens[Index])
- if not mark then
- return nil
- end
- fastForwardToken(finish + 1)
- local startPos = getPosition(start, 'left')
- local finishMark = sgsub(mark, '%[', ']')
- local stringResult, finishPos = resolveLongString(finishMark)
- return {
- type = 'string',
- start = startPos,
- finish = finishPos,
- [1] = stringResult,
- [2] = mark,
- }
-end
-
-local function pushCommentHeadError(left)
- if State.options.nonstandardSymbol['//'] then
- return
- end
- pushError {
- type = 'ERR_COMMENT_PREFIX',
- start = left,
- finish = left + 2,
- fix = {
- title = 'FIX_COMMENT_PREFIX',
- {
- start = left,
- finish = left + 2,
- text = '--',
- },
- }
- }
-end
-
-local function pushLongCommentError(left, right)
- if State.options.nonstandardSymbol['/**/'] then
- return
- end
- pushError {
- type = 'ERR_C_LONG_COMMENT',
- start = left,
- finish = right,
- fix = {
- title = 'FIX_C_LONG_COMMENT',
- {
- start = left,
- finish = left + 2,
- text = '--[[',
- },
- {
- start = right - 2,
- finish = right,
- text = '--]]'
- },
- }
- }
-end
-
-local function skipComment(isAction)
- local token = Tokens[Index + 1]
- if token == '--'
- or (
- token == '//'
- and (
- isAction
- or State.options.nonstandardSymbol['//']
- )
- ) then
- local start = Tokens[Index]
- local left = getPosition(start, 'left')
- local chead = false
- if token == '//' then
- chead = true
- pushCommentHeadError(left)
- end
- Index = Index + 2
- local longComment = start + 2 == Tokens[Index] and parseLongString()
- if longComment then
- longComment.type = 'comment.long'
- longComment.text = longComment[1]
- longComment.mark = longComment[2]
- longComment[1] = nil
- longComment[2] = nil
- State.comms[#State.comms+1] = longComment
- return true
- end
- while true do
- local nl = Tokens[Index + 1]
- if not nl or NLMap[nl] then
- break
- end
- Index = Index + 2
- end
- State.comms[#State.comms+1] = {
- type = chead and 'comment.cshort' or 'comment.short',
- start = left,
- finish = lastRightPosition(),
- text = ssub(Lua, start + 2, Tokens[Index] and (Tokens[Index] - 1) or #Lua),
- }
- return true
- end
- if token == '/*' then
- local start = Tokens[Index]
- local left = getPosition(start, 'left')
- Index = Index + 2
- local result, right = resolveLongString '*/'
- pushLongCommentError(left, right)
- State.comms[#State.comms+1] = {
- type = 'comment.long',
- start = left,
- finish = right,
- text = result,
- }
- return true
- end
- return false
-end
-
-local function skipSpace(isAction)
- repeat until not skipNL()
- and not skipComment(isAction)
-end
-
-local function expectAssign(isAction)
- local token = Tokens[Index + 1]
- if token == '=' then
- Index = Index + 2
- return true
- end
- if token == '==' then
- local left = getPosition(Tokens[Index], 'left')
- local right = getPosition(Tokens[Index] + #token - 1, 'right')
- pushError {
- type = 'ERR_ASSIGN_AS_EQ',
- start = left,
- finish = right,
- fix = {
- title = 'FIX_ASSIGN_AS_EQ',
- {
- start = left,
- finish = right,
- text = '=',
- }
- }
- }
- Index = Index + 2
- return true
- end
- if isAction then
- if token == '+='
- or token == '-='
- or token == '*='
- or token == '/=' then
- if not State.options.nonstandardSymbol[token] then
- unknownSymbol()
- end
- Index = Index + 2
- return true
- end
- end
- return false
-end
-
-local function parseLocalAttrs()
- local attrs
- while true do
- skipSpace()
- local token = Tokens[Index + 1]
- if token ~= '<' then
- break
- end
- if not attrs then
- attrs = {
- type = 'localattrs',
- }
- end
- local attr = {
- type = 'localattr',
- parent = attrs,
- start = getPosition(Tokens[Index], 'left'),
- finish = getPosition(Tokens[Index], 'right'),
- }
- attrs[#attrs+1] = attr
- Index = Index + 2
- skipSpace()
- local word, wstart, wfinish = peekWord()
- if word then
- attr[1] = word
- attr.finish = wfinish
- Index = Index + 2
- if word ~= 'const'
- and word ~= 'close' then
- pushError {
- type = 'UNKNOWN_ATTRIBUTE',
- start = wstart,
- finish = wfinish,
- }
- end
- else
- missName()
- end
- attr.finish = lastRightPosition()
- skipSpace()
- if Tokens[Index + 1] == '>' then
- attr.finish = getPosition(Tokens[Index], 'right')
- Index = Index + 2
- elseif Tokens[Index + 1] == '>=' then
- attr.finish = getPosition(Tokens[Index], 'right')
- pushError {
- type = 'MISS_SPACE_BETWEEN',
- start = getPosition(Tokens[Index], 'left'),
- finish = getPosition(Tokens[Index] + 1, 'right'),
- }
- Index = Index + 2
- else
- missSymbol '>'
- end
- if State.version ~= 'Lua 5.4' then
- pushError {
- type = 'UNSUPPORT_SYMBOL',
- start = attr.start,
- finish = attr.finish,
- version = 'Lua 5.4',
- info = {
- version = State.version
- }
- }
- end
- end
- return attrs
-end
-
-local function createLocal(obj, attrs)
- obj.type = 'local'
- obj.effect = obj.finish
-
- if attrs then
- obj.attrs = attrs
- attrs.parent = obj
- end
-
- local chunk = Chunk[#Chunk]
- if chunk then
- local locals = chunk.locals
- if not locals then
- locals = {}
- chunk.locals = locals
- end
- locals[#locals+1] = obj
- LocalCount = LocalCount + 1
- if LocalCount > LocalLimit then
- pushError {
- type = 'LOCAL_LIMIT',
- start = obj.start,
- finish = obj.finish,
- }
- end
- end
- return obj
-end
-
-local function pushChunk(chunk)
- Chunk[#Chunk+1] = chunk
-end
-
-local function resolveLable(label, obj)
- if not label.ref then
- label.ref = {}
- end
- label.ref[#label.ref+1] = obj
- obj.node = label
-
- -- 如果有局部变量在 goto 与 label 之间声明,
- -- 并在 label 之后使用,则算作语法错误
-
- -- 如果 label 在 goto 之前声明,那么不会有中间声明的局部变量
- if obj.start > label.start then
- return
- end
-
- local block = guide.getBlock(obj)
- local locals = block and block.locals
- if not locals then
- return
- end
-
- for i = 1, #locals do
- local loc = locals[i]
- -- 检查局部变量声明位置为 goto 与 label 之间
- if loc.start < obj.start or loc.finish > label.finish then
- goto CONTINUE
- end
- -- 检查局部变量的使用位置在 label 之后
- local refs = loc.ref
- if not refs then
- goto CONTINUE
- end
- for j = 1, #refs do
- local ref = refs[j]
- if ref.finish > label.finish then
- pushError {
- type = 'JUMP_LOCAL_SCOPE',
- start = obj.start,
- finish = obj.finish,
- info = {
- loc = loc[1],
- },
- relative = {
- {
- start = label.start,
- finish = label.finish,
- },
- {
- start = loc.start,
- finish = loc.finish,
- }
- },
- }
- return
- end
- end
- ::CONTINUE::
- end
-end
-
-local function resolveGoTo(gotos)
- for i = 1, #gotos do
- local action = gotos[i]
- local label = guide.getLabel(action, action[1])
- if label then
- resolveLable(label, action)
- else
- pushError {
- type = 'NO_VISIBLE_LABEL',
- start = action.start,
- finish = action.finish,
- info = {
- label = action[1],
- }
- }
- end
- end
-end
-
-local function popChunk()
- local chunk = Chunk[#Chunk]
- if chunk.gotos then
- resolveGoTo(chunk.gotos)
- chunk.gotos = nil
- end
- local lastAction = chunk[#chunk]
- if lastAction then
- chunk.finish = lastAction.finish
- end
- Chunk[#Chunk] = nil
-end
-
-local function parseNil()
- if Tokens[Index + 1] ~= 'nil' then
- return nil
- end
- local offset = Tokens[Index]
- Index = Index + 2
- return {
- type = 'nil',
- start = getPosition(offset, 'left'),
- finish = getPosition(offset + 2, 'right'),
- }
-end
-
-local function parseBoolean()
- local word = Tokens[Index+1]
- if word ~= 'true'
- and word ~= 'false' then
- return nil
- end
- local start = getPosition(Tokens[Index], 'left')
- local finish = getPosition(Tokens[Index] + #word - 1, 'right')
- Index = Index + 2
- return {
- type = 'boolean',
- start = start,
- finish = finish,
- [1] = word == 'true' and true or false,
- }
-end
-
-local function parseStringUnicode()
- local offset = Tokens[Index] + 1
- if ssub(Lua, offset, offset) ~= '{' then
- local pos = getPosition(offset, 'left')
- missSymbol('{', pos)
- return nil, offset
- end
- local leftPos = getPosition(offset, 'left')
- local x16 = smatch(Lua, '^%w*', offset + 1)
- local rightPos = getPosition(offset + #x16, 'right')
- offset = offset + #x16 + 1
- if ssub(Lua, offset, offset) == '}' then
- offset = offset + 1
- rightPos = rightPos + 1
- else
- missSymbol('}', rightPos)
- end
- offset = offset + 1
- if #x16 == 0 then
- pushError {
- type = 'UTF8_SMALL',
- start = leftPos,
- finish = rightPos,
- }
- return '', offset
- end
- if State.version ~= 'Lua 5.3'
- and State.version ~= 'Lua 5.4'
- and State.version ~= 'LuaJIT'
- then
- pushError {
- type = 'ERR_ESC',
- start = leftPos - 2,
- finish = rightPos,
- version = {'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
- info = {
- version = State.version,
- }
- }
- return nil, offset
- end
- local byte = tonumber(x16, 16)
- if not byte then
- for i = 1, #x16 do
- if not tonumber(ssub(x16, i, i), 16) then
- pushError {
- type = 'MUST_X16',
- start = leftPos + i,
- finish = leftPos + i + 1,
- }
- end
- end
- return nil, offset
- end
- if State.version == 'Lua 5.4' then
- if byte < 0 or byte > 0x7FFFFFFF then
- pushError {
- type = 'UTF8_MAX',
- start = leftPos,
- finish = rightPos,
- info = {
- min = '00000000',
- max = '7FFFFFFF',
- }
- }
- return nil, offset
- end
- else
- if byte < 0 or byte > 0x10FFFF then
- pushError {
- type = 'UTF8_MAX',
- start = leftPos,
- finish = rightPos,
- version = byte <= 0x7FFFFFFF and 'Lua 5.4' or nil,
- info = {
- min = '000000',
- max = '10FFFF',
- }
- }
- end
- end
- if byte >= 0 and byte <= 0x10FFFF then
- return uchar(byte), offset
- end
- return '', offset
-end
-
-local stringPool = {}
-local function parseShortString()
- local mark = Tokens[Index+1]
- local startOffset = Tokens[Index]
- local startPos = getPosition(startOffset, 'left')
- Index = Index + 2
- local stringIndex = 0
- local currentOffset = startOffset + 1
- local escs = {}
- while true do
- local token = Tokens[Index + 1]
- if token == mark then
- stringIndex = stringIndex + 1
- stringPool[stringIndex] = ssub(Lua, currentOffset, Tokens[Index] - 1)
- Index = Index + 2
- break
- end
- if NLMap[token] then
- stringIndex = stringIndex + 1
- stringPool[stringIndex] = ssub(Lua, currentOffset, Tokens[Index] - 1)
- missSymbol(mark)
- break
- end
- if not token then
- stringIndex = stringIndex + 1
- stringPool[stringIndex] = ssub(Lua, currentOffset or -1)
- missSymbol(mark)
- break
- end
- if token == '\\' then
- stringIndex = stringIndex + 1
- stringPool[stringIndex] = ssub(Lua, currentOffset, Tokens[Index] - 1)
- currentOffset = Tokens[Index]
- Index = Index + 2
- if not Tokens[Index] then
- goto CONTINUE
- end
- local escLeft = getPosition(currentOffset, 'left')
- -- has space?
- if Tokens[Index] - currentOffset > 1 then
- local right = getPosition(currentOffset + 1, 'right')
- pushError {
- type = 'ERR_ESC',
- start = escLeft,
- finish = right,
- }
- escs[#escs+1] = escLeft
- escs[#escs+1] = right
- escs[#escs+1] = 'err'
- goto CONTINUE
- end
- local nextToken = ssub(Tokens[Index + 1], 1, 1)
- if EscMap[nextToken] then
- stringIndex = stringIndex + 1
- stringPool[stringIndex] = EscMap[nextToken]
- currentOffset = Tokens[Index] + #nextToken
- Index = Index + 2
- escs[#escs+1] = escLeft
- escs[#escs+1] = escLeft + 2
- escs[#escs+1] = 'normal'
- goto CONTINUE
- end
- if nextToken == mark then
- stringIndex = stringIndex + 1
- stringPool[stringIndex] = mark
- currentOffset = Tokens[Index] + #nextToken
- Index = Index + 2
- escs[#escs+1] = escLeft
- escs[#escs+1] = escLeft + 2
- escs[#escs+1] = 'normal'
- goto CONTINUE
- end
- if nextToken == 'z' then
- Index = Index + 2
- repeat until not skipNL()
- currentOffset = Tokens[Index]
- escs[#escs+1] = escLeft
- escs[#escs+1] = escLeft + 2
- escs[#escs+1] = 'normal'
- goto CONTINUE
- end
- if CharMapNumber[nextToken] then
- local numbers = smatch(Tokens[Index + 1], '^%d+')
- if #numbers > 3 then
- numbers = ssub(numbers, 1, 3)
- end
- currentOffset = Tokens[Index] + #numbers
- fastForwardToken(currentOffset)
- local right = getPosition(currentOffset - 1, 'right')
- local byte = tointeger(numbers)
- if byte <= 255 then
- stringIndex = stringIndex + 1
- stringPool[stringIndex] = schar(byte)
- else
- pushError {
- type = 'ERR_ESC',
- start = escLeft,
- finish = right,
- }
- end
- escs[#escs+1] = escLeft
- escs[#escs+1] = right
- escs[#escs+1] = 'byte'
- goto CONTINUE
- end
- if nextToken == 'x' then
- local left = getPosition(Tokens[Index] - 1, 'left')
- local x16 = ssub(Tokens[Index + 1], 2, 3)
- local byte = tonumber(x16, 16)
- if byte then
- currentOffset = Tokens[Index] + 3
- stringIndex = stringIndex + 1
- stringPool[stringIndex] = schar(byte)
- else
- currentOffset = Tokens[Index] + 1
- pushError {
- type = 'MISS_ESC_X',
- start = getPosition(currentOffset, 'left'),
- finish = getPosition(currentOffset + 1, 'right'),
- }
- end
- local right = getPosition(currentOffset + 1, 'right')
- escs[#escs+1] = escLeft
- escs[#escs+1] = right
- escs[#escs+1] = 'byte'
- if State.version == 'Lua 5.1' then
- pushError {
- type = 'ERR_ESC',
- start = left,
- finish = left + 4,
- version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
- info = {
- version = State.version,
- }
- }
- end
- Index = Index + 2
- goto CONTINUE
- end
- if nextToken == 'u' then
- local str, newOffset = parseStringUnicode()
- if str then
- stringIndex = stringIndex + 1
- stringPool[stringIndex] = str
- end
- currentOffset = newOffset
- fastForwardToken(currentOffset - 1)
- local right = getPosition(currentOffset + 1, 'right')
- escs[#escs+1] = escLeft
- escs[#escs+1] = right
- escs[#escs+1] = 'unicode'
- goto CONTINUE
- end
- if NLMap[nextToken] then
- stringIndex = stringIndex + 1
- stringPool[stringIndex] = '\n'
- currentOffset = Tokens[Index] + #nextToken
- skipNL()
- local right = getPosition(currentOffset + 1, 'right')
- escs[#escs+1] = escLeft
- escs[#escs+1] = escLeft + 1
- escs[#escs+1] = 'normal'
- goto CONTINUE
- end
- local right = getPosition(currentOffset + 1, 'right')
- pushError {
- type = 'ERR_ESC',
- start = escLeft,
- finish = right,
- }
- escs[#escs+1] = escLeft
- escs[#escs+1] = right
- escs[#escs+1] = 'err'
- end
- Index = Index + 2
- ::CONTINUE::
- end
- local stringResult = tconcat(stringPool, '', 1, stringIndex)
- local str = {
- type = 'string',
- start = startPos,
- finish = lastRightPosition(),
- escs = #escs > 0 and escs or nil,
- [1] = stringResult,
- [2] = mark,
- }
- if mark == '`' then
- if not State.options.nonstandardSymbol[mark] then
- pushError {
- type = 'ERR_NONSTANDARD_SYMBOL',
- start = startPos,
- finish = str.finish,
- info = {
- symbol = '"',
- },
- fix = {
- title = 'FIX_NONSTANDARD_SYMBOL',
- symbol = '"',
- {
- start = startPos,
- finish = startPos + 1,
- text = '"',
- },
- {
- start = str.finish - 1,
- finish = str.finish,
- text = '"',
- },
- }
- }
- end
- end
- return str
-end
-
-local function parseString()
- local c = Tokens[Index + 1]
- if CharMapStrSH[c] then
- return parseShortString()
- end
- if CharMapStrLH[c] then
- return parseLongString()
- end
- return nil
-end
-
-local function parseNumber10(start)
- local integer = true
- local integerPart = smatch(Lua, '^%d*', start)
- local offset = start + #integerPart
- -- float part
- if ssub(Lua, offset, offset) == '.' then
- local floatPart = smatch(Lua, '^%d*', offset + 1)
- integer = false
- offset = offset + #floatPart + 1
- end
- -- exp part
- local echar = ssub(Lua, offset, offset)
- if CharMapE10[echar] then
- integer = false
- offset = offset + 1
- local nextChar = ssub(Lua, offset, offset)
- if CharMapSign[nextChar] then
- offset = offset + 1
- end
- local exp = smatch(Lua, '^%d*', offset)
- offset = offset + #exp
- if #exp == 0 then
- pushError {
- type = 'MISS_EXPONENT',
- start = getPosition(offset - 1, 'right'),
- finish = getPosition(offset - 1, 'right'),
- }
- end
- end
- return tonumber(ssub(Lua, start, offset - 1)), offset, integer
-end
-
-local function parseNumber16(start)
- local integerPart = smatch(Lua, '^[%da-fA-F]*', start)
- local offset = start + #integerPart
- local integer = true
- -- float part
- if ssub(Lua, offset, offset) == '.' then
- local floatPart = smatch(Lua, '^[%da-fA-F]*', offset + 1)
- integer = false
- offset = offset + #floatPart + 1
- if #integerPart == 0 and #floatPart == 0 then
- pushError {
- type = 'MUST_X16',
- start = getPosition(offset - 1, 'right'),
- finish = getPosition(offset - 1, 'right'),
- }
- end
- else
- if #integerPart == 0 then
- pushError {
- type = 'MUST_X16',
- start = getPosition(offset - 1, 'right'),
- finish = getPosition(offset - 1, 'right'),
- }
- return 0, offset
- end
- end
- -- exp part
- local echar = ssub(Lua, offset, offset)
- if CharMapE16[echar] then
- integer = false
- offset = offset + 1
- local nextChar = ssub(Lua, offset, offset)
- if CharMapSign[nextChar] then
- offset = offset + 1
- end
- local exp = smatch(Lua, '^%d*', offset)
- offset = offset + #exp
- end
- local n = tonumber(ssub(Lua, start - 2, offset - 1))
- return n, offset, integer
-end
-
-local function parseNumber2(start)
- local bins = smatch(Lua, '^[01]*', start)
- local offset = start + #bins
- if State.version ~= 'LuaJIT' then
- pushError {
- type = 'UNSUPPORT_SYMBOL',
- start = getPosition(start - 2, 'left'),
- finish = getPosition(offset - 1, 'right'),
- version = 'LuaJIT',
- info = {
- version = 'Lua 5.4',
- }
- }
- end
- return tonumber(bins, 2), offset
-end
-
-local function dropNumberTail(offset, integer)
- local _, finish, word = sfind(Lua, '^([%.%w_\x80-\xff]+)', offset)
- if not finish then
- return offset
- end
- if integer then
- if supper(ssub(word, 1, 2)) == 'LL' then
- if State.version ~= 'LuaJIT' then
- pushError {
- type = 'UNSUPPORT_SYMBOL',
- start = getPosition(offset, 'left'),
- finish = getPosition(offset + 1, 'right'),
- version = 'LuaJIT',
- info = {
- version = State.version,
- }
- }
- end
- offset = offset + 2
- word = ssub(word, offset)
- elseif supper(ssub(word, 1, 3)) == 'ULL' then
- if State.version ~= 'LuaJIT' then
- pushError {
- type = 'UNSUPPORT_SYMBOL',
- start = getPosition(offset, 'left'),
- finish = getPosition(offset + 2, 'right'),
- version = 'LuaJIT',
- info = {
- version = State.version,
- }
- }
- end
- offset = offset + 3
- word = ssub(word, offset)
- end
- end
- if supper(ssub(word, 1, 1)) == 'I' then
- if State.version ~= 'LuaJIT' then
- pushError {
- type = 'UNSUPPORT_SYMBOL',
- start = getPosition(offset, 'left'),
- finish = getPosition(offset, 'right'),
- version = 'LuaJIT',
- info = {
- version = State.version,
- }
- }
- end
- offset = offset + 1
- word = ssub(word, offset)
- end
- if #word > 0 then
- pushError {
- type = 'MALFORMED_NUMBER',
- start = getPosition(offset, 'left'),
- finish = getPosition(finish, 'right'),
- }
- end
- return finish + 1
-end
-
-local function parseNumber()
- local offset = Tokens[Index]
- if not offset then
- return nil
- end
- local startPos = getPosition(offset, 'left')
- local neg
- if ssub(Lua, offset, offset) == '-' then
- neg = true
- offset = offset + 1
- end
- local number, integer
- local firstChar = ssub(Lua, offset, offset)
- if firstChar == '.' then
- number, offset = parseNumber10(offset)
- integer = false
- elseif firstChar == '0' then
- local nextChar = ssub(Lua, offset + 1, offset + 1)
- if CharMapN16[nextChar] then
- number, offset, integer = parseNumber16(offset + 2)
- elseif CharMapN2[nextChar] then
- number, offset = parseNumber2(offset + 2)
- integer = true
- else
- number, offset, integer = parseNumber10(offset)
- end
- elseif CharMapNumber[firstChar] then
- number, offset, integer = parseNumber10(offset)
- else
- return nil
- end
- if not number then
- number = 0
- end
- if neg then
- number = - number
- end
- local result = {
- type = integer and 'integer' or 'number',
- start = startPos,
- finish = getPosition(offset - 1, 'right'),
- [1] = number,
- }
- offset = dropNumberTail(offset, integer)
- fastForwardToken(offset)
- return result
-end
-
-local function isKeyWord(word)
- if KeyWord[word] then
- return true
- end
- if word == 'goto' then
- return State.version ~= 'Lua 5.1'
- end
- return false
-end
-
-local function parseName(asAction)
- local word = peekWord()
- if not word then
- return nil
- end
- if ChunkFinishMap[word] then
- return nil
- end
- if asAction and ChunkStartMap[word] then
- return nil
- end
- local startPos = getPosition(Tokens[Index], 'left')
- local finishPos = getPosition(Tokens[Index] + #word - 1, 'right')
- Index = Index + 2
- if not State.options.unicodeName and word:find '[\x80-\xff]' then
- pushError {
- type = 'UNICODE_NAME',
- start = startPos,
- finish = finishPos,
- }
- end
- if isKeyWord(word) then
- pushError {
- type = 'KEYWORD',
- start = startPos,
- finish = finishPos,
- }
- end
- return {
- type = 'name',
- start = startPos,
- finish = finishPos,
- [1] = word,
- }
-end
-
-local function parseNameOrList(parent)
- local first = parseName()
- if not first then
- return nil
- end
- skipSpace()
- local list
- while true do
- if Tokens[Index + 1] ~= ',' then
- break
- end
- Index = Index + 2
- skipSpace()
- local name = parseName(true)
- if not name then
- missName()
- break
- end
- if not list then
- list = {
- type = 'list',
- start = first.start,
- finish = first.finish,
- parent = parent,
- [1] = first
- }
- end
- list[#list+1] = name
- list.finish = name.finish
- end
- return list or first
-end
-
-local function dropTail()
- local token = Tokens[Index + 1]
- if token ~= '?'
- and token ~= ':' then
- return
- end
- local pl, pt, pp = 0, 0, 0
- while true do
- local token = Tokens[Index + 1]
- if not token then
- break
- end
- if NLMap[token] then
- break
- end
- if token == ',' then
- if pl > 0
- or pt > 0
- or pp > 0 then
- goto CONTINUE
- else
- break
- end
- end
- if token == '<' then
- pl = pl + 1
- goto CONTINUE
- end
- if token == '{' then
- pt = pt + 1
- goto CONTINUE
- end
- if token == '(' then
- pp = pp + 1
- goto CONTINUE
- end
- if token == '>' then
- if pl <= 0 then
- break
- end
- pl = pl - 1
- goto CONTINUE
- end
- if token == '}' then
- if pt <= 0 then
- break
- end
- pt = pt - 1
- goto CONTINUE
- end
- if token == ')' then
- if pp <= 0 then
- break
- end
- pp = pp - 1
- goto CONTINUE
- end
- ::CONTINUE::
- Index = Index + 2
- end
-end
-
-local function parseExpList(mini)
- local list
- local wantSep = false
- while true do
- skipSpace()
- local token = Tokens[Index + 1]
- if not token then
- break
- end
- if ListFinishMap[token] then
- break
- end
- if token == ',' then
- local sepPos = getPosition(Tokens[Index], 'right')
- if not wantSep then
- pushError {
- type = 'UNEXPECT_SYMBOL',
- start = getPosition(Tokens[Index], 'left'),
- finish = sepPos,
- info = {
- symbol = ',',
- }
- }
- end
- wantSep = false
- Index = Index + 2
- goto CONTINUE
- else
- if mini then
- if wantSep then
- break
- end
- local nextToken = peekWord()
- if isKeyWord(nextToken)
- and nextToken ~= 'function'
- and nextToken ~= 'true'
- and nextToken ~= 'false'
- and nextToken ~= 'nil'
- and nextToken ~= 'not' then
- break
- end
- end
- local exp = parseExp()
- if not exp then
- break
- end
- dropTail()
- if wantSep then
- missSymbol(',', list[#list].finish, exp.start)
- end
- wantSep = true
- if not list then
- list = {
- type = 'list',
- start = exp.start,
- }
- end
- list[#list+1] = exp
- list.finish = exp.finish
- exp.parent = list
- end
- ::CONTINUE::
- end
- if not list then
- return nil
- end
- if not wantSep then
- missExp()
- end
- return list
-end
-
-local function parseIndex()
- local start = getPosition(Tokens[Index], 'left')
- Index = Index + 2
- skipSpace()
- local exp = parseExp()
- local index = {
- type = 'index',
- start = start,
- finish = exp and exp.finish or (start + 1),
- index = exp
- }
- if exp then
- exp.parent = index
- else
- missExp()
- end
- skipSpace()
- if Tokens[Index + 1] == ']' then
- index.finish = getPosition(Tokens[Index], 'right')
- Index = Index + 2
- else
- missSymbol ']'
- end
- return index
-end
-
-local function parseTable()
- local tbl = {
- type = 'table',
- start = getPosition(Tokens[Index], 'left'),
- finish = getPosition(Tokens[Index], 'right'),
- }
- Index = Index + 2
- local index = 0
- local tindex = 0
- local wantSep = false
- while true do
- skipSpace(true)
- local token = Tokens[Index + 1]
- if token == '}' then
- Index = Index + 2
- break
- end
- if CharMapTSep[token] then
- if not wantSep then
- missExp()
- end
- wantSep = false
- Index = Index + 2
- goto CONTINUE
- end
- local lastRight = lastRightPosition()
-
- if peekWord() then
- local savePoint = getSavePoint()
- local name = parseName()
- if name then
- skipSpace()
- if Tokens[Index + 1] == '=' then
- Index = Index + 2
- if wantSep then
- pushError {
- type = 'MISS_SEP_IN_TABLE',
- start = lastRight,
- finish = getPosition(Tokens[Index], 'left'),
- }
- end
- wantSep = true
- local eqRight = lastRightPosition()
- skipSpace()
- local fvalue = parseExp()
- local tfield = {
- type = 'tablefield',
- start = name.start,
- finish = fvalue and fvalue.finish or eqRight,
- parent = tbl,
- field = name,
- value = fvalue,
- }
- name.type = 'field'
- name.parent = tfield
- if fvalue then
- fvalue.parent = tfield
- else
- missExp()
- end
- index = index + 1
- tbl[index] = tfield
- goto CONTINUE
- end
- end
- savePoint()
- end
-
- local exp = parseExp(true)
- if exp then
- if wantSep then
- pushError {
- type = 'MISS_SEP_IN_TABLE',
- start = lastRight,
- finish = exp.start,
- }
- end
- wantSep = true
- if exp.type == 'varargs' then
- index = index + 1
- tbl[index] = exp
- exp.parent = tbl
- goto CONTINUE
- end
- index = index + 1
- tindex = tindex + 1
- local texp = {
- type = 'tableexp',
- start = exp.start,
- finish = exp.finish,
- tindex = tindex,
- parent = tbl,
- value = exp,
- }
- exp.parent = texp
- tbl[index] = texp
- goto CONTINUE
- end
-
- if token == '[' then
- if wantSep then
- pushError {
- type = 'MISS_SEP_IN_TABLE',
- start = lastRight,
- finish = getPosition(Tokens[Index], 'left'),
- }
- end
- wantSep = true
- local tindex = parseIndex()
- skipSpace()
- tindex.type = 'tableindex'
- tindex.parent = tbl
- index = index + 1
- tbl[index] = tindex
- if expectAssign() then
- skipSpace()
- local ivalue = parseExp()
- if ivalue then
- ivalue.parent = tindex
- tindex.finish = ivalue.finish
- tindex.value = ivalue
- else
- missExp()
- end
- else
- missSymbol '='
- end
- goto CONTINUE
- end
-
- missSymbol '}'
- break
- ::CONTINUE::
- end
- tbl.finish = lastRightPosition()
- return tbl
-end
-
-local function addDummySelf(node, call)
- if node.type ~= 'getmethod' then
- return
- end
- -- dummy param `self`
- if not call.args then
- call.args = {
- type = 'callargs',
- start = call.start,
- finish = call.finish,
- parent = call,
- }
- end
- local self = {
- type = 'self',
- start = node.colon.start,
- finish = node.colon.finish,
- parent = call.args,
- [1] = 'self',
- }
- tinsert(call.args, 1, self)
-end
-
-local function parseSimple(node, funcName)
- local lastMethod
- while true do
- if lastMethod and node.node == lastMethod then
- if node.type ~= 'call' then
- missSymbol('(', node.node.finish, node.node.finish)
- end
- lastMethod = nil
- end
- skipSpace()
- local token = Tokens[Index + 1]
- if token == '.' then
- local dot = {
- type = token,
- start = getPosition(Tokens[Index], 'left'),
- finish = getPosition(Tokens[Index], 'right'),
- }
- Index = Index + 2
- skipSpace()
- local field = parseName(true)
- local getfield = {
- type = 'getfield',
- start = node.start,
- finish = lastRightPosition(),
- node = node,
- dot = dot,
- field = field
- }
- if field then
- field.parent = getfield
- field.type = 'field'
- else
- pushError {
- type = 'MISS_FIELD',
- start = lastRightPosition(),
- finish = lastRightPosition(),
- }
- end
- node.parent = getfield
- node.next = getfield
- node = getfield
- elseif token == ':' then
- local colon = {
- type = token,
- start = getPosition(Tokens[Index], 'left'),
- finish = getPosition(Tokens[Index], 'right'),
- }
- Index = Index + 2
- skipSpace()
- local method = parseName(true)
- local getmethod = {
- type = 'getmethod',
- start = node.start,
- finish = lastRightPosition(),
- node = node,
- colon = colon,
- method = method
- }
- if method then
- method.parent = getmethod
- method.type = 'method'
- else
- pushError {
- type = 'MISS_METHOD',
- start = lastRightPosition(),
- finish = lastRightPosition(),
- }
- end
- node.parent = getmethod
- node.next = getmethod
- node = getmethod
- if lastMethod then
- missSymbol('(', node.node.finish, node.node.finish)
- end
- lastMethod = getmethod
- elseif token == '(' then
- if funcName then
- break
- end
- local startPos = getPosition(Tokens[Index], 'left')
- local call = {
- type = 'call',
- start = node.start,
- node = node,
- }
- Index = Index + 2
- local args = parseExpList()
- if Tokens[Index + 1] == ')' then
- call.finish = getPosition(Tokens[Index], 'right')
- Index = Index + 2
- else
- call.finish = lastRightPosition()
- missSymbol ')'
- end
- if args then
- args.type = 'callargs'
- args.start = startPos
- args.finish = call.finish
- args.parent = call
- call.args = args
- end
- addDummySelf(node, call)
- node.parent = call
- node = call
- elseif token == '{' then
- if funcName then
- break
- end
- local tbl = parseTable()
- local call = {
- type = 'call',
- start = node.start,
- finish = tbl.finish,
- node = node,
- }
- local args = {
- type = 'callargs',
- start = tbl.start,
- finish = tbl.finish,
- parent = call,
- [1] = tbl,
- }
- call.args = args
- addDummySelf(node, call)
- tbl.parent = args
- node.parent = call
- node = call
- elseif CharMapStrSH[token] then
- if funcName then
- break
- end
- local str = parseShortString()
- local call = {
- type = 'call',
- start = node.start,
- finish = str.finish,
- node = node,
- }
- local args = {
- type = 'callargs',
- start = str.start,
- finish = str.finish,
- parent = call,
- [1] = str,
- }
- call.args = args
- addDummySelf(node, call)
- str.parent = args
- node.parent = call
- node = call
- elseif CharMapStrLH[token] then
- local str = parseLongString()
- if str then
- if funcName then
- break
- end
- local call = {
- type = 'call',
- start = node.start,
- finish = str.finish,
- node = node,
- }
- local args = {
- type = 'callargs',
- start = str.start,
- finish = str.finish,
- parent = call,
- [1] = str,
- }
- call.args = args
- addDummySelf(node, call)
- str.parent = args
- node.parent = call
- node = call
- else
- local index = parseIndex()
- local bstart = index.start
- index.type = 'getindex'
- index.start = node.start
- index.node = node
- node.next = index
- node.parent = index
- node = index
- if funcName then
- pushError {
- type = 'INDEX_IN_FUNC_NAME',
- start = bstart,
- finish = index.finish,
- }
- end
- end
- else
- break
- end
- end
- if node.type == 'call'
- and node.node == lastMethod then
- lastMethod = nil
- end
- if node == lastMethod then
- if funcName then
- lastMethod = nil
- end
- end
- if lastMethod then
- missSymbol('(', lastMethod.finish)
- end
- return node
-end
-
-local function parseVarargs()
- local varargs = {
- type = 'varargs',
- start = getPosition(Tokens[Index], 'left'),
- finish = getPosition(Tokens[Index] + 2, 'right'),
- }
- Index = Index + 2
- for i = #Chunk, 1, -1 do
- local chunk = Chunk[i]
- if chunk.vararg then
- if not chunk.vararg.ref then
- chunk.vararg.ref = {}
- end
- chunk.vararg.ref[#chunk.vararg.ref+1] = varargs
- varargs.node = chunk.vararg
- break
- end
- if chunk.type == 'main' then
- break
- end
- if chunk.type == 'function' then
- pushError {
- type = 'UNEXPECT_DOTS',
- start = varargs.start,
- finish = varargs.finish,
- }
- break
- end
- end
- return varargs
-end
-
-local function parseParen()
- local pl = Tokens[Index]
- local paren = {
- type = 'paren',
- start = getPosition(pl, 'left'),
- finish = getPosition(pl, 'right')
- }
- Index = Index + 2
- skipSpace()
- local exp = parseExp()
- if exp then
- paren.exp = exp
- paren.finish = exp.finish
- exp.parent = paren
- else
- missExp()
- end
- skipSpace()
- if Tokens[Index + 1] == ')' then
- paren.finish = getPosition(Tokens[Index], 'right')
- Index = Index + 2
- else
- missSymbol ')'
- end
- return paren
-end
-
-local function getLocal(name, pos)
- for i = #Chunk, 1, -1 do
- local chunk = Chunk[i]
- local locals = chunk.locals
- if locals then
- local res
- for n = 1, #locals do
- local loc = locals[n]
- if loc.effect > pos then
- break
- end
- if loc[1] == name then
- if not res or res.effect < loc.effect then
- res = loc
- end
- end
- end
- if res then
- return res
- end
- end
- end
-end
-
-local function resolveName(node)
- if not node then
- return nil
- end
- local loc = getLocal(node[1], node.start)
- if loc then
- node.type = 'getlocal'
- node.node = loc
- if not loc.ref then
- loc.ref = {}
- end
- loc.ref[#loc.ref+1] = node
- if loc.special then
- addSpecial(loc.special, node)
- end
- else
- node.type = 'getglobal'
- local env = getLocal(State.ENVMode, node.start)
- if env then
- node.node = env
- if not env.ref then
- env.ref = {}
- end
- env.ref[#env.ref+1] = node
- end
- end
- local name = node[1]
- if Specials[name] then
- addSpecial(name, node)
- else
- local ospeicals = State.options.special
- if ospeicals and ospeicals[name] then
- addSpecial(ospeicals[name], node)
- end
- end
- return node
-end
-
-local function isChunkFinishToken(token)
- local currentChunk = Chunk[#Chunk]
- if not currentChunk then
- return false
- end
- local tp = currentChunk.type
- if tp == 'main' then
- return false
- end
- if tp == 'for'
- or tp == 'in'
- or tp == 'loop'
- or tp == 'function' then
- return token == 'end'
- end
- if tp == 'if'
- or tp == 'ifblock'
- or tp == 'elseifblock'
- or tp == 'elseblock' then
- return token == 'then'
- or token == 'end'
- or token == 'else'
- or token == 'elseif'
- end
- if tp == 'repeat' then
- return token == 'until'
- end
- return true
-end
-
-local function parseActions()
- local rtn, last
- while true do
- skipSpace(true)
- local token = Tokens[Index + 1]
- if token == ';' then
- Index = Index + 2
- goto CONTINUE
- end
- if ChunkFinishMap[token]
- and isChunkFinishToken(token) then
- break
- end
- local action, failed = parseAction()
- if failed then
- if not skipUnknownSymbol() then
- break
- end
- end
- if action then
- if not rtn and action.type == 'return' then
- rtn = action
- end
- last = action
- end
- ::CONTINUE::
- end
- if rtn and rtn ~= last then
- pushError {
- type = 'ACTION_AFTER_RETURN',
- start = rtn.start,
- finish = rtn.finish,
- }
- end
-end
-
-local function parseParams(params)
- local lastSep
- local hasDots
- while true do
- skipSpace()
- local token = Tokens[Index + 1]
- if not token or token == ')' then
- if lastSep then
- missName()
- end
- break
- end
- if token == ',' then
- if lastSep or lastSep == nil then
- missName()
- else
- lastSep = true
- end
- Index = Index + 2
- goto CONTINUE
- end
- if token == '...' then
- if lastSep == false then
- missSymbol ','
- end
- lastSep = false
- if not params then
- params = {}
- end
- local vararg = {
- type = '...',
- start = getPosition(Tokens[Index], 'left'),
- finish = getPosition(Tokens[Index] + 2, 'right'),
- parent = params,
- [1] = '...',
- }
- local chunk = Chunk[#Chunk]
- chunk.vararg = vararg
- params[#params+1] = vararg
- if hasDots then
- pushError {
- type = 'ARGS_AFTER_DOTS',
- start = getPosition(Tokens[Index], 'left'),
- finish = getPosition(Tokens[Index] + 2, 'right'),
- }
- end
- hasDots = true
- Index = Index + 2
- goto CONTINUE
- end
- if CharMapWord[ssub(token, 1, 1)] then
- if lastSep == false then
- missSymbol ','
- end
- lastSep = false
- if not params then
- params = {}
- end
- params[#params+1] = createLocal {
- start = getPosition(Tokens[Index], 'left'),
- finish = getPosition(Tokens[Index] + #token - 1, 'right'),
- parent = params,
- [1] = token,
- }
- if hasDots then
- pushError {
- type = 'ARGS_AFTER_DOTS',
- start = getPosition(Tokens[Index], 'left'),
- finish = getPosition(Tokens[Index] + #token - 1, 'right'),
- }
- end
- if isKeyWord(token) then
- pushError {
- type = 'KEYWORD',
- start = getPosition(Tokens[Index], 'left'),
- finish = getPosition(Tokens[Index] + #token - 1, 'right'),
- }
- end
- Index = Index + 2
- goto CONTINUE
- end
- skipUnknownSymbol '%,%)%.'
- ::CONTINUE::
- end
- return params
-end
-
-local function parseFunction(isLocal, isAction)
- local funcLeft = getPosition(Tokens[Index], 'left')
- local funcRight = getPosition(Tokens[Index] + 7, 'right')
- local func = {
- type = 'function',
- start = funcLeft,
- finish = funcRight,
- keyword = {
- [1] = funcLeft,
- [2] = funcRight,
- },
- }
- Index = Index + 2
- local LastLocalCount = LocalCount
- LocalCount = 0
- skipSpace(true)
- local hasLeftParen = Tokens[Index + 1] == '('
- if not hasLeftParen then
- local name = parseName()
- if name then
- local simple = parseSimple(name, true)
- if isLocal then
- if simple == name then
- createLocal(name)
- else
- resolveName(name)
- pushError {
- type = 'UNEXPECT_LFUNC_NAME',
- start = simple.start,
- finish = simple.finish,
- }
- end
- else
- resolveName(name)
- end
- func.name = simple
- func.finish = simple.finish
- if not isAction then
- simple.parent = func
- pushError {
- type = 'UNEXPECT_EFUNC_NAME',
- start = simple.start,
- finish = simple.finish,
- }
- end
- skipSpace(true)
- hasLeftParen = Tokens[Index + 1] == '('
- end
- end
- pushChunk(func)
- local params
- if func.name and func.name.type == 'getmethod' then
- if func.name.type == 'getmethod' then
- params = {}
- params[1] = createLocal {
- start = funcRight,
- finish = funcRight,
- parent = params,
- [1] = 'self',
- }
- params[1].type = 'self'
- end
- end
- if hasLeftParen then
- local parenLeft = getPosition(Tokens[Index], 'left')
- Index = Index + 2
- params = parseParams(params)
- if params then
- params.type = 'funcargs'
- params.start = parenLeft
- params.finish = lastRightPosition()
- params.parent = func
- func.args = params
- end
- skipSpace(true)
- if Tokens[Index + 1] == ')' then
- local parenRight = getPosition(Tokens[Index], 'right')
- func.finish = parenRight
- if params then
- params.finish = parenRight
- end
- Index = Index + 2
- skipSpace(true)
- else
- func.finish = lastRightPosition()
- if params then
- params.finish = func.finish
- end
- missSymbol ')'
- end
- else
- missSymbol '('
- end
- parseActions()
- popChunk()
- if Tokens[Index + 1] == 'end' then
- local endLeft = getPosition(Tokens[Index], 'left')
- local endRight = getPosition(Tokens[Index] + 2, 'right')
- func.keyword[3] = endLeft
- func.keyword[4] = endRight
- func.finish = endRight
- Index = Index + 2
- else
- missEnd(funcLeft, funcRight)
- end
- LocalCount = LastLocalCount
- return func
-end
-
-local function parseExpUnit()
- local token = Tokens[Index + 1]
- if token == '(' then
- local paren = parseParen()
- return parseSimple(paren, false)
- end
-
- if token == '...' then
- local varargs = parseVarargs()
- return varargs
- end
-
- if token == '{' then
- local table = parseTable()
- return table
- end
-
- if CharMapStrSH[token] then
- local string = parseShortString()
- return string
- end
-
- if CharMapStrLH[token] then
- local string = parseLongString()
- return string
- end
-
- local number = parseNumber()
- if number then
- return number
- end
-
- if ChunkFinishMap[token] then
- return nil
- end
-
- if token == 'nil' then
- return parseNil()
- end
-
- if token == 'true'
- or token == 'false' then
- return parseBoolean()
- end
-
- if token == 'function' then
- return parseFunction()
- end
-
- local node = parseName()
- if node then
- return parseSimple(resolveName(node), false)
- end
-
- return nil
-end
-
-local function parseUnaryOP()
- local token = Tokens[Index + 1]
- local symbol = UnarySymbol[token] and token or UnaryAlias[token]
- if not symbol then
- return nil
- end
- local myLevel = UnarySymbol[symbol]
- local op = {
- type = symbol,
- start = getPosition(Tokens[Index], 'left'),
- finish = getPosition(Tokens[Index] + #symbol - 1, 'right'),
- }
- Index = Index + 2
- return op, myLevel
-end
-
----@param level integer # op level must greater than this level
-local function parseBinaryOP(asAction, level)
- local token = Tokens[Index + 1]
- local symbol = (BinarySymbol[token] and token)
- or BinaryAlias[token]
- or (not asAction and BinaryActionAlias[token])
- if not symbol then
- return nil
- end
- if symbol == '//' and State.options.nonstandardSymbol['//'] then
- return nil
- end
- local myLevel = BinarySymbol[symbol]
- if level and myLevel < level then
- return nil
- end
- local op = {
- type = symbol,
- start = getPosition(Tokens[Index], 'left'),
- finish = getPosition(Tokens[Index] + #token - 1, 'right'),
- }
- if not asAction then
- if token == '=' then
- pushError {
- type = 'ERR_EQ_AS_ASSIGN',
- start = op.start,
- finish = op.finish,
- fix = {
- title = 'FIX_EQ_AS_ASSIGN',
- {
- start = op.start,
- finish = op.finish,
- text = '==',
- }
- }
- }
- end
- end
- if BinaryAlias[token] then
- if not State.options.nonstandardSymbol[token] then
- pushError {
- type = 'ERR_NONSTANDARD_SYMBOL',
- start = op.start,
- finish = op.finish,
- info = {
- symbol = symbol,
- },
- fix = {
- title = 'FIX_NONSTANDARD_SYMBOL',
- symbol = symbol,
- {
- start = op.start,
- finish = op.finish,
- text = symbol,
- },
- }
- }
- end
- end
- if token == '//'
- or token == '<<'
- or token == '>>' then
- if State.version ~= 'Lua 5.3'
- and State.version ~= 'Lua 5.4' then
- pushError {
- type = 'UNSUPPORT_SYMBOL',
- version = {'Lua 5.3', 'Lua 5.4'},
- start = op.start,
- finish = op.finish,
- info = {
- version = State.version,
- }
- }
- end
- end
- Index = Index + 2
- return op, myLevel
-end
-
-function parseExp(asAction, level)
- local exp
- local uop, uopLevel = parseUnaryOP()
- if uop then
- skipSpace()
- local child = parseExp(asAction, uopLevel)
- -- 预计算负数
- if uop.type == '-'
- and child
- and (child.type == 'number' or child.type == 'integer') then
- child.start = uop.start
- child[1] = - child[1]
- exp = child
- else
- exp = {
- type = 'unary',
- op = uop,
- start = uop.start,
- finish = child and child.finish or uop.finish,
- [1] = child,
- }
- if child then
- child.parent = exp
- else
- missExp()
- end
- end
- else
- exp = parseExpUnit()
- if not exp then
- return nil
- end
- end
-
- while true do
- skipSpace()
- local bop, bopLevel = parseBinaryOP(asAction, level)
- if not bop then
- break
- end
-
- ::AGAIN::
- skipSpace()
- local isForward = SymbolForward[bopLevel]
- local child = parseExp(asAction, isForward and (bopLevel + 0.5) or bopLevel)
- if not child then
- if skipUnknownSymbol() then
- goto AGAIN
- else
- missExp()
- end
- end
- local bin = {
- type = 'binary',
- start = exp.start,
- finish = child and child.finish or bop.finish,
- op = bop,
- [1] = exp,
- [2] = child
- }
- exp.parent = bin
- if child then
- child.parent = bin
- end
- exp = bin
- end
-
- return exp
-end
-
-local function skipSeps()
- while true do
- skipSpace()
- if Tokens[Index + 1] == ',' then
- missExp()
- Index = Index + 2
- else
- break
- end
- end
-end
-
----@return parser.object first
----@return parser.object second
----@return parser.object[] rest
-local function parseSetValues()
- skipSpace()
- local first = parseExp()
- if not first then
- return nil
- end
- skipSpace()
- if Tokens[Index + 1] ~= ',' then
- return first
- end
- Index = Index + 2
- skipSeps()
- local second = parseExp()
- if not second then
- missExp()
- return first
- end
- skipSpace()
- if Tokens[Index + 1] ~= ',' then
- return first, second
- end
- Index = Index + 2
- skipSeps()
- local third = parseExp()
- if not third then
- missExp()
- return first, second
- end
-
- local rest = { third }
- while true do
- skipSpace()
- if Tokens[Index + 1] ~= ',' then
- return first, second, rest
- end
- Index = Index + 2
- skipSeps()
- local exp = parseExp()
- if not exp then
- missExp()
- return first, second, rest
- end
- rest[#rest+1] = exp
- end
-end
-
-local function pushActionIntoCurrentChunk(action)
- local chunk = Chunk[#Chunk]
- if chunk then
- chunk[#chunk+1] = action
- action.parent = chunk
- end
-end
-
----@return parser.object second
----@return parser.object[] rest
-local function parseVarTails(parser, isLocal)
- if Tokens[Index + 1] ~= ',' then
- return
- end
- Index = Index + 2
- skipSpace()
- local second = parser(true)
- if not second then
- missName()
- return
- end
- if isLocal then
- createLocal(second, parseLocalAttrs())
- second.effect = maxinteger
- end
- skipSpace()
- if Tokens[Index + 1] ~= ',' then
- return second
- end
- Index = Index + 2
- skipSeps()
- local third = parser(true)
- if not third then
- missName()
- return second
- end
- if isLocal then
- createLocal(third, parseLocalAttrs())
- third.effect = maxinteger
- end
- local rest = { third }
- while true do
- skipSpace()
- if Tokens[Index + 1] ~= ',' then
- return second, rest
- end
- Index = Index + 2
- skipSeps()
- local name = parser(true)
- if not name then
- missName()
- return second, rest
- end
- if isLocal then
- createLocal(name, parseLocalAttrs())
- name.effect = maxinteger
- end
- rest[#rest+1] = name
- end
-end
-
-local function bindValue(n, v, index, lastValue, isLocal, isSet)
- if isLocal then
- n.effect = lastRightPosition()
- if v and v.special then
- addSpecial(v.special, n)
- end
- elseif isSet then
- n.type = GetToSetMap[n.type] or n.type
- if n.type == 'setlocal' then
- local loc = n.node
- if loc.attrs then
- pushError {
- type = 'SET_CONST',
- start = n.start,
- finish = n.finish,
- }
- end
- end
- end
- if not v and lastValue then
- if lastValue.type == 'call'
- or lastValue.type == 'varargs' then
- v = lastValue
- if not v.extParent then
- v.extParent = {}
- end
- end
- end
- if v then
- if v.type == 'call'
- or v.type == 'varargs' then
- local select = {
- type = 'select',
- sindex = index,
- start = v.start,
- finish = v.finish,
- vararg = v
- }
- if v.parent then
- v.extParent[#v.extParent+1] = select
- else
- v.parent = select
- end
- v = select
- end
- n.value = v
- n.range = v.finish
- v.parent = n
- if isLocal then
- n.effect = lastRightPosition()
- end
- end
-end
-
-local function parseMultiVars(n1, parser, isLocal)
- local n2, nrest = parseVarTails(parser, isLocal)
- skipSpace()
- local v1, v2, vrest
- local isSet
- local max = 1
- if expectAssign(not isLocal) then
- v1, v2, vrest = parseSetValues()
- isSet = true
- if not v1 then
- missExp()
- end
- end
- bindValue(n1, v1, 1, nil, isLocal, isSet)
- local lastValue = v1
- if n2 then
- max = 2
- bindValue(n2, v2, 2, lastValue, isLocal, isSet)
- lastValue = v2 or lastValue
- pushActionIntoCurrentChunk(n2)
- end
- if nrest then
- for i = 1, #nrest do
- local n = nrest[i]
- local v = vrest and vrest[i]
- max = i + 2
- bindValue(n, v, max, lastValue, isLocal, isSet)
- lastValue = v or lastValue
- pushActionIntoCurrentChunk(n)
- end
- end
-
- if v2 and not n2 then
- v2.redundant = {
- max = max,
- passed = 2,
- }
- pushActionIntoCurrentChunk(v2)
- end
- if vrest then
- for i = 1, #vrest do
- local v = vrest[i]
- if not nrest or not nrest[i] then
- v.redundant = {
- max = max,
- passed = i + 2,
- }
- pushActionIntoCurrentChunk(v)
- end
- end
- end
-
- return n1, isSet
-end
-
-local function compileExpAsAction(exp)
- pushActionIntoCurrentChunk(exp)
- if GetToSetMap[exp.type] then
- skipSpace()
- local action, isSet = parseMultiVars(exp, parseExp)
- if isSet
- or action.type == 'getmethod' then
- return action
- end
- end
-
- if exp.type == 'call' then
- return exp
- end
-
- if exp.type == 'binary' then
- if GetToSetMap[exp[1].type] then
- local op = exp.op
- if op.type == '==' then
- pushError {
- type = 'ERR_ASSIGN_AS_EQ',
- start = op.start,
- finish = op.finish,
- fix = {
- title = 'FIX_ASSIGN_AS_EQ',
- {
- start = op.start,
- finish = op.finish,
- text = '=',
- }
- }
- }
- return
- end
- end
- end
-
- pushError {
- type = 'EXP_IN_ACTION',
- start = exp.start,
- finish = exp.finish,
- }
-
- return exp
-end
-
-local function parseLocal()
- local locPos = getPosition(Tokens[Index], 'left')
- Index = Index + 2
- skipSpace()
- local word = peekWord()
- if not word then
- missName()
- return nil
- end
-
- if word == 'function' then
- local func = parseFunction(true, true)
- local name = func.name
- if name then
- func.name = nil
- name.value = func
- name.vstart = func.start
- name.range = func.finish
- name.locPos = locPos
- func.parent = name
- pushActionIntoCurrentChunk(name)
- return name
- else
- missName(func.keyword[2])
- pushActionIntoCurrentChunk(func)
- return func
- end
- end
-
- local name = parseName(true)
- if not name then
- missName()
- return nil
- end
- local loc = createLocal(name, parseLocalAttrs())
- loc.locPos = locPos
- loc.effect = maxinteger
- pushActionIntoCurrentChunk(loc)
- skipSpace()
- parseMultiVars(loc, parseName, true)
- if loc.value then
- loc.effect = loc.value.finish
- else
- loc.effect = loc.finish
- end
-
- return loc
-end
-
-local function parseDo()
- local doLeft = getPosition(Tokens[Index], 'left')
- local doRight = getPosition(Tokens[Index] + 1, 'right')
- local obj = {
- type = 'do',
- start = doLeft,
- finish = doRight,
- keyword = {
- [1] = doLeft,
- [2] = doRight,
- },
- }
- Index = Index + 2
- pushActionIntoCurrentChunk(obj)
- pushChunk(obj)
- parseActions()
- popChunk()
- if Tokens[Index + 1] == 'end' then
- obj.finish = getPosition(Tokens[Index] + 2, 'right')
- obj.keyword[3] = getPosition(Tokens[Index], 'left')
- obj.keyword[4] = getPosition(Tokens[Index] + 2, 'right')
- Index = Index + 2
- else
- missEnd(doLeft, doRight)
- end
- if obj.locals then
- LocalCount = LocalCount - #obj.locals
- end
-
- return obj
-end
-
-local function parseReturn()
- local returnLeft = getPosition(Tokens[Index], 'left')
- local returnRight = getPosition(Tokens[Index] + 5, 'right')
- Index = Index + 2
- skipSpace()
- local rtn = parseExpList(true)
- if rtn then
- rtn.type = 'return'
- rtn.start = returnLeft
- else
- rtn = {
- type = 'return',
- start = returnLeft,
- finish = returnRight,
- }
- end
- pushActionIntoCurrentChunk(rtn)
- for i = #Chunk, 1, -1 do
- local block = Chunk[i]
- if block.type == 'function'
- or block.type == 'main' then
- if not block.returns then
- block.returns = {}
- end
- block.returns[#block.returns+1] = rtn
- break
- end
- end
- for i = #Chunk, 1, -1 do
- local block = Chunk[i]
- if block.type == 'ifblock'
- or block.type == 'elseifblock'
- or block.type == 'else' then
- block.hasReturn = true
- break
- end
- end
-
- return rtn
-end
-
-local function parseLabel()
- local left = getPosition(Tokens[Index], 'left')
- Index = Index + 2
- skipSpace()
- local label = parseName()
- skipSpace()
-
- if not label then
- missName()
- end
-
- if Tokens[Index + 1] == '::' then
- Index = Index + 2
- else
- if label then
- missSymbol '::'
- end
- end
-
- if not label then
- return nil
- end
-
- label.type = 'label'
- pushActionIntoCurrentChunk(label)
-
- local block = guide.getBlock(label)
- if block then
- if not block.labels then
- block.labels = {}
- end
- local name = label[1]
- local olabel = guide.getLabel(block, name)
- if olabel then
- if State.version == 'Lua 5.4'
- or block == guide.getBlock(olabel) then
- pushError {
- type = 'REDEFINED_LABEL',
- start = label.start,
- finish = label.finish,
- relative = {
- {
- olabel.start,
- olabel.finish,
- }
- }
- }
- end
- end
- block.labels[name] = label
- end
-
- if State.version == 'Lua 5.1' then
- pushError {
- type = 'UNSUPPORT_SYMBOL',
- start = left,
- finish = lastRightPosition(),
- version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
- info = {
- version = State.version,
- }
- }
- return
- end
- return label
-end
-
-local function parseGoTo()
- local start = getPosition(Tokens[Index], 'left')
- Index = Index + 2
- skipSpace()
-
- local action = parseName()
- if not action then
- missName()
- return nil
- end
-
- action.type = 'goto'
- action.keyStart = start
-
- for i = #Chunk, 1, -1 do
- local chunk = Chunk[i]
- if chunk.type == 'function'
- or chunk.type == 'main' then
- if not chunk.gotos then
- chunk.gotos = {}
- end
- chunk.gotos[#chunk.gotos+1] = action
- break
- end
- end
- for i = #Chunk, 1, -1 do
- local chunk = Chunk[i]
- if chunk.type == 'ifblock'
- or chunk.type == 'elseifblock'
- or chunk.type == 'elseblock' then
- chunk.hasGoTo = true
- break
- end
- end
-
- pushActionIntoCurrentChunk(action)
- return action
-end
-
-local function parseIfBlock(parent)
- local ifLeft = getPosition(Tokens[Index], 'left')
- local ifRight = getPosition(Tokens[Index] + 1, 'right')
- Index = Index + 2
- local ifblock = {
- type = 'ifblock',
- parent = parent,
- start = ifLeft,
- finish = ifRight,
- keyword = {
- [1] = ifLeft,
- [2] = ifRight,
- }
- }
- skipSpace()
- local filter = parseExp()
- if filter then
- ifblock.filter = filter
- ifblock.finish = filter.finish
- filter.parent = ifblock
- else
- missExp()
- end
- skipSpace()
- local thenToken = Tokens[Index + 1]
- if thenToken == 'then'
- or thenToken == 'do' then
- ifblock.finish = getPosition(Tokens[Index] + #thenToken - 1, 'right')
- ifblock.keyword[3] = getPosition(Tokens[Index], 'left')
- ifblock.keyword[4] = ifblock.finish
- if thenToken == 'do' then
- pushError {
- type = 'ERR_THEN_AS_DO',
- start = ifblock.keyword[3],
- finish = ifblock.keyword[4],
- fix = {
- title = 'FIX_THEN_AS_DO',
- {
- start = ifblock.keyword[3],
- finish = ifblock.keyword[4],
- text = 'then',
- }
- }
- }
- end
- Index = Index + 2
- else
- missSymbol 'then'
- end
- pushChunk(ifblock)
- parseActions()
- popChunk()
- ifblock.finish = lastRightPosition()
- if ifblock.locals then
- LocalCount = LocalCount - #ifblock.locals
- end
- return ifblock
-end
-
-local function parseElseIfBlock(parent)
- local ifLeft = getPosition(Tokens[Index], 'left')
- local ifRight = getPosition(Tokens[Index] + 5, 'right')
- local elseifblock = {
- type = 'elseifblock',
- parent = parent,
- start = ifLeft,
- finish = ifRight,
- keyword = {
- [1] = ifLeft,
- [2] = ifRight,
- }
- }
- Index = Index + 2
- skipSpace()
- local filter = parseExp()
- if filter then
- elseifblock.filter = filter
- elseifblock.finish = filter.finish
- filter.parent = elseifblock
- else
- missExp()
- end
- skipSpace()
- local thenToken = Tokens[Index + 1]
- if thenToken == 'then'
- or thenToken == 'do' then
- elseifblock.finish = getPosition(Tokens[Index] + #thenToken - 1, 'right')
- elseifblock.keyword[3] = getPosition(Tokens[Index], 'left')
- elseifblock.keyword[4] = elseifblock.finish
- if thenToken == 'do' then
- pushError {
- type = 'ERR_THEN_AS_DO',
- start = elseifblock.keyword[3],
- finish = elseifblock.keyword[4],
- fix = {
- title = 'FIX_THEN_AS_DO',
- {
- start = elseifblock.keyword[3],
- finish = elseifblock.keyword[4],
- text = 'then',
- }
- }
- }
- end
- Index = Index + 2
- else
- missSymbol 'then'
- end
- pushChunk(elseifblock)
- parseActions()
- popChunk()
- elseifblock.finish = lastRightPosition()
- if elseifblock.locals then
- LocalCount = LocalCount - #elseifblock.locals
- end
- return elseifblock
-end
-
-local function parseElseBlock(parent)
- local ifLeft = getPosition(Tokens[Index], 'left')
- local ifRight = getPosition(Tokens[Index] + 3, 'right')
- local elseblock = {
- type = 'elseblock',
- parent = parent,
- start = ifLeft,
- finish = ifRight,
- keyword = {
- [1] = ifLeft,
- [2] = ifRight,
- }
- }
- Index = Index + 2
- skipSpace()
- pushChunk(elseblock)
- parseActions()
- popChunk()
- elseblock.finish = lastRightPosition()
- if elseblock.locals then
- LocalCount = LocalCount - #elseblock.locals
- end
- return elseblock
-end
-
-local function parseIf()
- local token = Tokens[Index + 1]
- local left = getPosition(Tokens[Index], 'left')
- local action = {
- type = 'if',
- start = left,
- finish = getPosition(Tokens[Index] + #token - 1, 'right'),
- }
- pushActionIntoCurrentChunk(action)
- if token ~= 'if' then
- missSymbol('if', left, left)
- end
- local hasElse
- while true do
- local word = Tokens[Index + 1]
- local child
- if word == 'if' then
- child = parseIfBlock(action)
- elseif word == 'elseif' then
- child = parseElseIfBlock(action)
- elseif word == 'else' then
- child = parseElseBlock(action)
- end
- if not child then
- break
- end
- if hasElse then
- pushError {
- type = 'BLOCK_AFTER_ELSE',
- start = child.start,
- finish = child.finish,
- }
- end
- if word == 'else' then
- hasElse = true
- end
- action[#action+1] = child
- action.finish = child.finish
- skipSpace()
- end
-
- if Tokens[Index + 1] == 'end' then
- action.finish = getPosition(Tokens[Index] + 2, 'right')
- Index = Index + 2
- else
- missEnd(action[1].keyword[1], action[1].keyword[2])
- end
-
- return action
-end
-
-local function parseFor()
- local action = {
- type = 'for',
- start = getPosition(Tokens[Index], 'left'),
- finish = getPosition(Tokens[Index] + 2, 'right'),
- keyword = {},
- }
- action.keyword[1] = action.start
- action.keyword[2] = action.finish
- Index = Index + 2
- pushActionIntoCurrentChunk(action)
- pushChunk(action)
- skipSpace()
- local nameOrList = parseNameOrList(action)
- if not nameOrList then
- missName()
- end
- skipSpace()
- -- for i =
- if expectAssign() then
- action.type = 'loop'
-
- skipSpace()
- local expList = parseExpList()
- local name
- if nameOrList then
- if nameOrList.type == 'name' then
- name = nameOrList
- else
- name = nameOrList[1]
- end
- end
- if name then
- local loc = createLocal(name)
- loc.parent = action
- action.finish = name.finish
- action.loc = loc
- end
- if expList then
- expList.parent = action
- local value = expList[1]
- if value then
- value.parent = expList
- action.init = value
- action.finish = expList[#expList].finish
- end
- local max = expList[2]
- if max then
- max.parent = expList
- action.max = max
- action.finish = max.finish
- else
- pushError {
- type = 'MISS_LOOP_MAX',
- start = lastRightPosition(),
- finish = lastRightPosition(),
- }
- end
- local step = expList[3]
- if step then
- step.parent = expList
- action.step = step
- action.finish = step.finish
- end
- else
- pushError {
- type = 'MISS_LOOP_MIN',
- start = lastRightPosition(),
- finish = lastRightPosition(),
- }
- end
-
- if action.loc then
- action.loc.effect = action.finish
- end
- elseif Tokens[Index + 1] == 'in' then
- action.type = 'in'
- local inLeft = getPosition(Tokens[Index], 'left')
- local inRight = getPosition(Tokens[Index] + 1, 'right')
- Index = Index + 2
- skipSpace()
-
- local exps = parseExpList()
-
- action.finish = inRight
- action.keyword[3] = inLeft
- action.keyword[4] = inRight
-
- local list
- if nameOrList and nameOrList.type == 'name' then
- list = {
- type = 'list',
- start = nameOrList.start,
- finish = nameOrList.finish,
- parent = action,
- [1] = nameOrList,
- }
- else
- list = nameOrList
- end
-
- if exps then
- local lastExp = exps[#exps]
- if lastExp then
- action.finish = lastExp.finish
- end
-
- action.exps = exps
- exps.parent = action
- for i = 1, #exps do
- local exp = exps[i]
- exp.parent = exps
- end
- else
- missExp()
- end
-
- if list then
- local lastName = list[#list]
- list.range = lastName and lastName.range or inRight
- action.keys = list
- for i = 1, #list do
- local loc = createLocal(list[i])
- loc.parent = action
- loc.effect = action.finish
- end
- end
- else
- missSymbol 'in'
- end
-
- skipSpace()
- local doToken = Tokens[Index + 1]
- if doToken == 'do'
- or doToken == 'then' then
- local left = getPosition(Tokens[Index], 'left')
- local right = getPosition(Tokens[Index] + #doToken - 1, 'right')
- action.finish = left
- action.keyword[#action.keyword+1] = left
- action.keyword[#action.keyword+1] = right
- if doToken == 'then' then
- pushError {
- type = 'ERR_DO_AS_THEN',
- start = left,
- finish = right,
- fix = {
- title = 'FIX_DO_AS_THEN',
- {
- start = left,
- finish = right,
- text = 'do',
- }
- }
- }
- end
- Index = Index + 2
- else
- missSymbol 'do'
- end
-
- skipSpace()
- parseActions()
- popChunk()
-
- skipSpace()
- if Tokens[Index + 1] == 'end' then
- action.finish = getPosition(Tokens[Index] + 2, 'right')
- action.keyword[#action.keyword+1] = getPosition(Tokens[Index], 'left')
- action.keyword[#action.keyword+1] = action.finish
- Index = Index + 2
- else
- missEnd(action.keyword[1], action.keyword[2])
- end
-
- if action.locals then
- LocalCount = LocalCount - #action.locals
- end
-
- return action
-end
-
-local function parseWhile()
- local action = {
- type = 'while',
- start = getPosition(Tokens[Index], 'left'),
- finish = getPosition(Tokens[Index] + 4, 'right'),
- keyword = {},
- }
- action.keyword[1] = action.start
- action.keyword[2] = action.finish
- Index = Index + 2
-
- skipSpace()
- local nextToken = Tokens[Index + 1]
- local filter = nextToken ~= 'do'
- and nextToken ~= 'then'
- and parseExp()
- if filter then
- action.filter = filter
- action.finish = filter.finish
- filter.parent = action
- else
- missExp()
- end
-
- skipSpace()
- local doToken = Tokens[Index + 1]
- if doToken == 'do'
- or doToken == 'then' then
- local left = getPosition(Tokens[Index], 'left')
- local right = getPosition(Tokens[Index] + #doToken - 1, 'right')
- action.finish = left
- action.keyword[#action.keyword+1] = left
- action.keyword[#action.keyword+1] = right
- if doToken == 'then' then
- pushError {
- type = 'ERR_DO_AS_THEN',
- start = left,
- finish = right,
- fix = {
- title = 'FIX_DO_AS_THEN',
- {
- start = left,
- finish = right,
- text = 'do',
- }
- }
- }
- end
- Index = Index + 2
- else
- missSymbol 'do'
- end
-
- pushActionIntoCurrentChunk(action)
- pushChunk(action)
- skipSpace()
- parseActions()
- popChunk()
-
- skipSpace()
- if Tokens[Index + 1] == 'end' then
- action.finish = getPosition(Tokens[Index] + 2, 'right')
- action.keyword[#action.keyword+1] = getPosition(Tokens[Index], 'left')
- action.keyword[#action.keyword+1] = action.finish
- Index = Index + 2
- else
- missEnd(action.keyword[1], action.keyword[2])
- end
-
- if action.locals then
- LocalCount = LocalCount - #action.locals
- end
-
- return action
-end
-
-local function parseRepeat()
- local action = {
- type = 'repeat',
- start = getPosition(Tokens[Index], 'left'),
- finish = getPosition(Tokens[Index] + 5, 'right'),
- keyword = {},
- }
- action.keyword[1] = action.start
- action.keyword[2] = action.finish
- Index = Index + 2
-
- pushActionIntoCurrentChunk(action)
- pushChunk(action)
- skipSpace()
- parseActions()
-
- skipSpace()
- if Tokens[Index + 1] == 'until' then
- action.finish = getPosition(Tokens[Index] + 4, 'right')
- action.keyword[#action.keyword+1] = getPosition(Tokens[Index], 'left')
- action.keyword[#action.keyword+1] = action.finish
- Index = Index + 2
-
- skipSpace()
- local filter = parseExp()
- if filter then
- action.filter = filter
- filter.parent = action
- else
- missExp()
- end
-
- else
- missSymbol 'until'
- end
-
- popChunk()
- if action.filter then
- action.finish = action.filter.finish
- end
-
- if action.locals then
- LocalCount = LocalCount - #action.locals
- end
-
- return action
-end
-
-local function parseBreak()
- local returnLeft = getPosition(Tokens[Index], 'left')
- local returnRight = getPosition(Tokens[Index] + #Tokens[Index + 1] - 1, 'right')
- Index = Index + 2
- skipSpace()
- local action = {
- type = 'break',
- start = returnLeft,
- finish = returnRight,
- }
-
- local ok
- for i = #Chunk, 1, -1 do
- local chunk = Chunk[i]
- if chunk.type == 'function' then
- break
- end
- if chunk.type == 'while'
- or chunk.type == 'in'
- or chunk.type == 'loop'
- or chunk.type == 'repeat'
- or chunk.type == 'for' then
- if not chunk.breaks then
- chunk.breaks = {}
- end
- chunk.breaks[#chunk.breaks+1] = action
- ok = true
- break
- end
- end
- for i = #Chunk, 1, -1 do
- local chunk = Chunk[i]
- if chunk.type == 'ifblock'
- or chunk.type == 'elseifblock'
- or chunk.type == 'elseblock' then
- chunk.hasBreak = true
- break
- end
- end
- if not ok and Mode == 'Lua' then
- pushError {
- type = 'BREAK_OUTSIDE',
- start = action.start,
- finish = action.finish,
- }
- end
-
- pushActionIntoCurrentChunk(action)
- return action
-end
-
-function parseAction()
- local token = Tokens[Index + 1]
-
- if token == '::' then
- return parseLabel()
- end
-
- if token == 'local' then
- return parseLocal()
- end
-
- if token == 'if'
- or token == 'elseif'
- or token == 'else' then
- return parseIf()
- end
-
- if token == 'for' then
- return parseFor()
- end
-
- if token == 'do' then
- return parseDo()
- end
-
- if token == 'return' then
- return parseReturn()
- end
-
- if token == 'break' then
- return parseBreak()
- end
-
- if token == 'continue' and State.options.nonstandardSymbol['continue'] then
- return parseBreak()
- end
-
- if token == 'while' then
- return parseWhile()
- end
-
- if token == 'repeat' then
- return parseRepeat()
- end
-
- if token == 'goto' and isKeyWord 'goto' then
- return parseGoTo()
- end
-
- if token == 'function' then
- local exp = parseFunction(false, true)
- local name = exp.name
- if name then
- exp.name = nil
- name.type = GetToSetMap[name.type]
- name.value = exp
- name.vstart = exp.start
- name.range = exp.finish
- exp.parent = name
- if name.type == 'setlocal' then
- local loc = name.node
- if loc.attrs then
- pushError {
- type = 'SET_CONST',
- start = name.start,
- finish = name.finish,
- }
- end
- end
- pushActionIntoCurrentChunk(name)
- return name
- else
- pushActionIntoCurrentChunk(exp)
- missName(exp.keyword[2])
- return exp
- end
- end
-
- local exp = parseExp(true)
- if exp then
- local action = compileExpAsAction(exp)
- if action then
- return action
- end
- end
- return nil, true
-end
-
-local function skipFirstComment()
- if Tokens[Index + 1] ~= '#' then
- return
- end
- while true do
- Index = Index + 2
- local token = Tokens[Index + 1]
- if not token then
- break
- end
- if NLMap[token] then
- skipNL()
- break
- end
- end
-end
-
-local function parseLua()
- local main = {
- type = 'main',
- start = 0,
- finish = 0,
- }
- pushChunk(main)
- createLocal{
- type = 'local',
- start = -1,
- finish = -1,
- effect = -1,
- parent = main,
- tag = '_ENV',
- special= '_G',
- [1] = State.ENVMode,
- }
- LocalCount = 0
- skipFirstComment()
- while true do
- parseActions()
- if Index <= #Tokens then
- unknownSymbol()
- Index = Index + 2
- else
- break
- end
- end
- popChunk()
- main.finish = getPosition(#Lua, 'right')
-
- return main
-end
-
-local function initState(lua, version, options)
- Lua = lua
- Line = 0
- LineOffset = 1
- LastTokenFinish = 0
- LocalCount = 0
- Chunk = {}
- Tokens = tokens(lua)
- Index = 1
- local state = {
- version = version,
- lua = lua,
- ast = {},
- errs = {},
- comms = {},
- lines = {
- [0] = 1,
- },
- options = options or {},
- }
- if not state.options.nonstandardSymbol then
- state.options.nonstandardSymbol = {}
- end
- State = state
- if version == 'Lua 5.1' or version == 'LuaJIT' then
- state.ENVMode = '@fenv'
- else
- state.ENVMode = '_ENV'
- end
-
- pushError = function (err)
- local errs = state.errs
- if err.finish < err.start then
- err.finish = err.start
- end
- local last = errs[#errs]
- if last then
- if last.start <= err.start and last.finish >= err.finish then
- return
- end
- end
- err.level = err.level or 'Error'
- errs[#errs+1] = err
- return err
- end
-
- state.pushError = pushError
-end
-
-return function (lua, mode, version, options)
- Mode = mode
- initState(lua, version, options)
- skipSpace()
- if mode == 'Lua' then
- State.ast = parseLua()
- elseif mode == 'Nil' then
- State.ast = parseNil()
- elseif mode == 'Boolean' then
- State.ast = parseBoolean()
- elseif mode == 'String' then
- State.ast = parseString()
- elseif mode == 'Number' then
- State.ast = parseNumber()
- elseif mode == 'Name' then
- State.ast = parseName()
- elseif mode == 'Exp' then
- State.ast = parseExp()
- elseif mode == 'Action' then
- State.ast = parseAction()
- end
-
- if State.ast then
- State.ast.state = State
- end
-
- while true do
- if Index <= #Tokens then
- unknownSymbol()
- Index = Index + 2
- else
- break
- end
- end
-
- return State
-end
diff --git a/script/parser/parse.lua b/script/parser/parse.lua
deleted file mode 100644
index e7c7d177..00000000
--- a/script/parser/parse.lua
+++ /dev/null
@@ -1,63 +0,0 @@
-local ast = require 'parser.ast'
-local grammar = require 'parser.grammar'
-
-local function buildState(lua, version, options)
- local errs = {}
- local diags = {}
- local comms = {}
- local state = {
- version = version,
- lua = lua,
- root = {},
- errs = errs,
- diags = diags,
- comms = comms,
- options = options or {},
- pushError = function (err)
- if err.finish < err.start then
- err.finish = err.start
- end
- local last = errs[#errs]
- if last then
- if last.start <= err.start and last.finish >= err.finish then
- return
- end
- end
- err.level = err.level or 'error'
- errs[#errs+1] = err
- return err
- end,
- pushDiag = function (code, info)
- if not diags[code] then
- diags[code] = {}
- end
- diags[code][#diags[code]+1] = info
- end,
- pushComment = function (comment)
- comms[#comms+1] = comment
- end
- }
- if version == 'Lua 5.1' or version == 'LuaJIT' then
- state.ENVMode = '@fenv'
- else
- state.ENVMode = '_ENV'
- end
- return state
-end
-
-return function (lua, mode, version, options)
- local state = buildState(lua, version, options)
- local clock = os.clock()
- ast.init(state)
- local suc, res, err = xpcall(grammar, debug.traceback, lua, mode)
- ast.close()
- if not suc then
- return nil, res
- end
- if not res and err then
- state.pushError(err)
- end
- state.ast = res
- state.parseClock = os.clock() - clock
- return state
-end
diff --git a/script/parser/split.lua b/script/parser/split.lua
deleted file mode 100644
index 6ce4a4e7..00000000
--- a/script/parser/split.lua
+++ /dev/null
@@ -1,9 +0,0 @@
-local m = require 'lpeglabel'
-
-local NL = m.P'\r\n' + m.S'\r\n'
-local LINE = m.C(1 - NL)
-
-return function (str)
- local MATCH = m.Ct((LINE * NL)^0 * LINE)
- return MATCH:match(str)
-end
diff --git a/script/parser/tokens.lua b/script/parser/tokens.lua
index 958f292e..a4de7f88 100644
--- a/script/parser/tokens.lua
+++ b/script/parser/tokens.lua
@@ -7,6 +7,11 @@ local Word = m.R('AZ', 'az', '__', '\x80\xff') * m.R('AZ', 'az', '09', '__', '
local Symbol = m.P'=='
+ m.P'~='
+ m.P'--'
+ -- non-standard:
+ + m.P'<<='
+ + m.P'>>='
+ + m.P'//='
+ -- end non-standard
+ m.P'<<'
+ m.P'>>'
+ m.P'<='
@@ -15,7 +20,7 @@ local Symbol = m.P'=='
+ m.P'...'
+ m.P'..'
+ m.P'::'
- -- incorrect
+ -- non-standard:
+ m.P'!='
+ m.P'&&'
+ m.P'||'
@@ -24,7 +29,12 @@ local Symbol = m.P'=='
+ m.P'+='
+ m.P'-='
+ m.P'*='
+ + m.P'%='
+ + m.P'&='
+ + m.P'|='
+ + m.P'^='
+ m.P'/='
+ -- end non-standard
-- singles
+ m.S'+-*/!#%^&()={}[]|\\\'":;<>,.?~`'
local Unknown = (1 - Number - Word - Symbol - Sp - Nl)^1
diff --git a/script/plugin.lua b/script/plugin.lua
index 145abe74..870b68b6 100644
--- a/script/plugin.lua
+++ b/script/plugin.lua
@@ -4,6 +4,8 @@ local client = require 'client'
local lang = require 'language'
local await = require 'await'
local scope = require 'workspace.scope'
+local ws = require 'workspace'
+local fs = require 'bee.filesystem'
---@class plugin
local m = {}
@@ -69,20 +71,35 @@ local function checkTrustLoad(scp)
return true
end
----@param scp scope
-function m.init(scp)
+---@param uri uri
+local function initPlugin(uri)
await.call(function () ---@async
- local ws = require 'workspace'
+ local scp = scope.getScope(uri)
local interface = {}
scp:set('pluginInterface', interface)
+ if not scp.uri then
+ return
+ end
+
local pluginPath = ws.getAbsolutePath(scp.uri, config.get(scp.uri, 'Lua.runtime.plugin'))
log.info('plugin path:', pluginPath)
if not pluginPath then
return
end
+
+ --Adding the plugins path to package.path allows for requires in files
+ --to find files relative to itself.
+ local oldPath = package.path
+ local path = fs.path(pluginPath):parent_path() / '?.lua'
+ if not package.path:find(path:string(), 1, true) then
+ package.path = package.path .. ';' .. path:string()
+ end
+
local pluginLua = util.loadFile(pluginPath)
if not pluginLua then
+ log.warn('plugin not found:', pluginPath)
+ package.path = oldPath
return
end
@@ -98,7 +115,8 @@ function m.init(scp)
if not client.isVSCode() and not checkTrustLoad(scp) then
return
end
- local suc, err = xpcall(f, log.error, f)
+ local pluginArgs = config.get(scp.uri, 'Lua.runtime.pluginArgs')
+ local suc, err = xpcall(f, log.error, f, uri, pluginArgs)
if not suc then
m.showError(scp, err)
return
@@ -108,4 +126,10 @@ function m.init(scp)
end)
end
+ws.watch(function (ev, uri)
+ if ev == 'startReload' then
+ initPlugin(uri)
+ end
+end)
+
return m
diff --git a/script/progress.lua b/script/progress.lua
index b43ed05b..f1f371f5 100644
--- a/script/progress.lua
+++ b/script/progress.lua
@@ -11,10 +11,10 @@ local m = {}
m.map = {}
---@class progress
----@field _uri uri
+---@field _uri uri
+---@field _token integer
local mt = {}
mt.__index = mt
-mt._token = nil
mt._title = nil
mt._message = nil
mt._removed = false
diff --git a/script/proto/converter.lua b/script/proto/converter.lua
index 9c75f056..3f5ddebc 100644
--- a/script/proto/converter.lua
+++ b/script/proto/converter.lua
@@ -13,7 +13,7 @@ local function rawPackPosition(uri, pos)
if col > 0 then
local state = files.getState(uri)
local text = files.getText(uri)
- if text then
+ if state and text then
local lineOffset = state.lines[row]
if lineOffset then
local start = lineOffset
diff --git a/script/proto/define.lua b/script/proto/define.lua
index fb60c56c..ecdaf306 100644
--- a/script/proto/define.lua
+++ b/script/proto/define.lua
@@ -1,3 +1,5 @@
+local diag = require 'proto.diagnostic'
+
local m = {}
--- 诊断等级
@@ -8,122 +10,21 @@ m.DiagnosticSeverity = {
Hint = 4,
}
----@alias DiagnosticDefaultSeverity
----| 'Hint'
----| 'Information'
----| 'Warning'
----| 'Error'
-
---- 诊断类型与默认等级
----@type table<string, DiagnosticDefaultSeverity>
-m.DiagnosticDefaultSeverity = {
- ['unused-local'] = 'Hint',
- ['unused-function'] = 'Hint',
- ['undefined-global'] = 'Warning',
- ['undefined-field'] = 'Warning',
- ['global-in-nil-env'] = 'Warning',
- ['unused-label'] = 'Hint',
- ['unused-vararg'] = 'Hint',
- ['trailing-space'] = 'Hint',
- ['redefined-local'] = 'Hint',
- ['newline-call'] = 'Information',
- ['newfield-call'] = 'Warning',
- ['redundant-parameter'] = 'Warning',
- ['missing-parameter'] = 'Warning',
- ['redundant-return'] = 'Warning',
- ['ambiguity-1'] = 'Warning',
- ['lowercase-global'] = 'Information',
- ['undefined-env-child'] = 'Information',
- ['duplicate-index'] = 'Warning',
- ['duplicate-set-field'] = 'Warning',
- ['empty-block'] = 'Hint',
- ['redundant-value'] = 'Warning',
- ['code-after-break'] = 'Hint',
- ['unbalanced-assignments'] = 'Warning',
- ['close-non-object'] = 'Warning',
- ['count-down-loop'] = 'Warning',
- ['no-unknown'] = 'Information',
- ['deprecated'] = 'Warning',
- ['different-requires'] = 'Warning',
- ['await-in-sync'] = 'Warning',
- ['not-yieldable'] = 'Warning',
- ['discard-returns'] = 'Warning',
- ['need-check-nil'] = 'Warning',
- ['type-check'] = 'Warning',
-
- ['duplicate-doc-alias'] = 'Warning',
- ['undefined-doc-class'] = 'Warning',
- ['undefined-doc-name'] = 'Warning',
- ['circle-doc-class'] = 'Warning',
- ['undefined-doc-param'] = 'Warning',
- ['duplicate-doc-param'] = 'Warning',
- ['doc-field-no-class'] = 'Warning',
- ['duplicate-doc-field'] = 'Warning',
- ['unknown-diag-code'] = 'Warning',
-
- ['codestyle-check'] = "Warning",
+m.DiagnosticFileStatus = {
+ Any = 1,
+ Opened = 2,
+ None = 3,
}
----@alias DiagnosticDefaultNeededFileStatus
----| 'Any'
----| 'Opened'
----| 'None'
-
--- 文件状态
-m.FileStatus = {
- Any = 1,
- Opened = 2,
-}
+--- 诊断类型与默认等级
+m.DiagnosticDefaultSeverity = diag.getDefaultSeverity()
--- 诊断类型与需要的文件状态(可以控制只分析打开的文件、还是所有文件)
----@type table<string, DiagnosticDefaultNeededFileStatus>
-m.DiagnosticDefaultNeededFileStatus = {
- ['unused-local'] = 'Opened',
- ['unused-function'] = 'Opened',
- ['undefined-global'] = 'Any',
- ['undefined-field'] = 'Opened',
- ['global-in-nil-env'] = 'Any',
- ['unused-label'] = 'Opened',
- ['unused-vararg'] = 'Opened',
- ['trailing-space'] = 'Opened',
- ['redefined-local'] = 'Opened',
- ['newline-call'] = 'Any',
- ['newfield-call'] = 'Any',
- ['redundant-parameter'] = 'Opened',
- ['missing-parameter'] = 'Opened',
- ['redundant-return'] = 'Opened',
- ['ambiguity-1'] = 'Any',
- ['lowercase-global'] = 'Any',
- ['undefined-env-child'] = 'Any',
- ['duplicate-index'] = 'Any',
- ['duplicate-set-field'] = 'Any',
- ['empty-block'] = 'Opened',
- ['redundant-value'] = 'Opened',
- ['code-after-break'] = 'Opened',
- ['unbalanced-assignments'] = 'Any',
- ['close-non-object'] = 'Any',
- ['count-down-loop'] = 'Any',
- ['no-unknown'] = 'None',
- ['deprecated'] = 'Opened',
- ['different-requires'] = 'Any',
- ['await-in-sync'] = 'None',
- ['not-yieldable'] = 'None',
- ['discard-returns'] = 'Opened',
- ['need-check-nil'] = 'Opened',
- ['type-check'] = 'None',
+m.DiagnosticDefaultNeededFileStatus = diag.getDefaultStatus()
- ['duplicate-doc-alias'] = 'Any',
- ['undefined-doc-class'] = 'Any',
- ['undefined-doc-name'] = 'Any',
- ['circle-doc-class'] = 'Any',
- ['undefined-doc-param'] = 'Any',
- ['duplicate-doc-param'] = 'Any',
- ['doc-field-no-class'] = 'Any',
- ['duplicate-doc-field'] = 'Any',
- ['unknown-diag-code'] = 'Any',
+m.DiagnosticDefaultGroupSeverity = diag.getGroupSeverity()
- ['codestyle-check'] = 'None',
-}
+m.DiagnosticDefaultGroupFileStatus = diag.getGroupStatus()
--- 诊断报告标签
m.DiagnosticTag = {
@@ -260,24 +161,27 @@ m.TokenTypes = {
["number"] = 19,
["regexp"] = 20,
["operator"] = 21,
+ ["decorator"] = 22,
}
m.BuiltIn = {
- ['basic'] = 'default',
- ['bit'] = 'default',
- ['bit32'] = 'default',
- ['builtin'] = 'default',
- ['coroutine'] = 'default',
- ['debug'] = 'default',
- ['ffi'] = 'default',
- ['io'] = 'default',
- ['jit'] = 'default',
- ['math'] = 'default',
- ['os'] = 'default',
- ['package'] = 'default',
- ['string'] = 'default',
- ['table'] = 'default',
- ['utf8'] = 'default',
+ ['basic'] = 'default',
+ ['bit'] = 'default',
+ ['bit32'] = 'default',
+ ['builtin'] = 'default',
+ ['coroutine'] = 'default',
+ ['debug'] = 'default',
+ ['ffi'] = 'default',
+ ['io'] = 'default',
+ ['jit'] = 'default',
+ ['math'] = 'default',
+ ['os'] = 'default',
+ ['package'] = 'default',
+ ['string'] = 'default',
+ ['table'] = 'default',
+ ['table.new'] = 'default',
+ ['table.clear'] = 'default',
+ ['utf8'] = 'default',
}
m.InlayHintKind = {
diff --git a/script/proto/diagnostic.lua b/script/proto/diagnostic.lua
new file mode 100644
index 00000000..9b0303cc
--- /dev/null
+++ b/script/proto/diagnostic.lua
@@ -0,0 +1,267 @@
+local util = require 'utility'
+
+---@class proto.diagnostic
+local m = {}
+
+---@alias DiagnosticSeverity
+---| 'Hint'
+---| 'Information'
+---| 'Warning'
+---| 'Error'
+
+---@alias DiagnosticNeededFileStatus
+---| 'Any'
+---| 'Opened'
+---| 'None'
+
+---@class proto.diagnostic.info
+---@field severity DiagnosticSeverity
+---@field status DiagnosticNeededFileStatus
+---@field group string
+
+m.diagnosticDatas = {}
+m.diagnosticGroups = {}
+
+function m.register(names)
+ ---@param info proto.diagnostic.info
+ return function (info)
+ for _, name in ipairs(names) do
+ m.diagnosticDatas[name] = {
+ severity = info.severity,
+ status = info.status,
+ }
+ if not m.diagnosticGroups[info.group] then
+ m.diagnosticGroups[info.group] = {}
+ end
+ m.diagnosticGroups[info.group][name] = true
+ end
+ end
+end
+
+m.register {
+ 'unused-local',
+ 'unused-function',
+ 'unused-label',
+ 'unused-vararg',
+ 'trailing-space',
+ 'redundant-return',
+ 'empty-block',
+ 'code-after-break',
+ 'unreachable-code',
+} {
+ group = 'unused',
+ severity = 'Hint',
+ status = 'Opened',
+}
+
+m.register {
+ 'redundant-value',
+ 'unbalanced-assignments',
+ 'redundant-parameter',
+ 'missing-parameter',
+ 'missing-return-value',
+ 'redundant-return-value',
+ 'missing-return',
+} {
+ group = 'unbalanced',
+ severity = 'Warning',
+ status = 'Any',
+}
+
+m.register {
+ 'need-check-nil',
+ 'undefined-field',
+ 'cast-local-type',
+ 'assign-type-mismatch',
+ 'param-type-mismatch',
+ 'cast-type-mismatch',
+ 'return-type-mismatch',
+} {
+ group = 'type-check',
+ severity = 'Warning',
+ status = 'Opened',
+}
+
+m.register {
+ 'duplicate-doc-alias',
+ 'undefined-doc-class',
+ 'undefined-doc-name',
+ 'circle-doc-class',
+ 'undefined-doc-param',
+ 'duplicate-doc-param',
+ 'doc-field-no-class',
+ 'duplicate-doc-field',
+ 'unknown-diag-code',
+ 'unknown-cast-variable',
+ 'unknown-operator',
+} {
+ group = 'luadoc',
+ severity = 'Warning',
+ status = 'Any',
+}
+
+m.register {
+ 'codestyle-check'
+} {
+ group = 'codestyle',
+ severity = 'Warning',
+ status = 'None',
+}
+
+m.register {
+ 'spell-check'
+} {
+ group = 'codestyle',
+ severity = 'Information',
+ status = 'None',
+}
+
+m.register {
+ 'newline-call',
+ 'newfield-call',
+ 'ambiguity-1',
+ 'count-down-loop',
+ 'different-requires',
+} {
+ group = 'ambiguity',
+ severity = 'Warning',
+ status = 'Any',
+}
+
+m.register {
+ 'await-in-sync',
+ 'not-yieldable',
+} {
+ group = 'await',
+ severity = 'Warning',
+ status = 'None',
+}
+
+m.register {
+ 'no-unknown',
+} {
+ group = 'strong',
+ severity = 'Warning',
+ status = 'None',
+}
+
+m.register {
+ 'redefined-local',
+} {
+ group = 'redefined',
+ severity = 'Hint',
+ status = 'Opened',
+}
+
+m.register {
+ 'undefined-global',
+ 'global-in-nil-env',
+} {
+ group = 'global',
+ severity = 'Warning',
+ status = 'Any',
+}
+
+m.register {
+ 'lowercase-global',
+ 'undefined-env-child',
+} {
+ group = 'global',
+ severity = 'Information',
+ status = 'Any',
+}
+
+m.register {
+ 'duplicate-index',
+ 'duplicate-set-field',
+} {
+ group = 'duplicate',
+ severity = 'Warning',
+ status = 'Any',
+}
+
+m.register {
+ 'close-non-object',
+ 'deprecated',
+ 'discard-returns',
+} {
+ group = 'strict',
+ severity = 'Warning',
+ status = 'Any',
+}
+
+---@return table<string, DiagnosticSeverity>
+function m.getDefaultSeverity()
+ local severity = {}
+ for name, info in pairs(m.diagnosticDatas) do
+ severity[name] = info.severity
+ end
+ return severity
+end
+
+---@return table<string, DiagnosticNeededFileStatus>
+function m.getDefaultStatus()
+ local status = {}
+ for name, info in pairs(m.diagnosticDatas) do
+ status[name] = info.status
+ end
+ return status
+end
+
+function m.getGroupSeverity()
+ local group = {}
+ for name in pairs(m.diagnosticGroups) do
+ group[name] = 'Fallback'
+ end
+ return group
+end
+
+function m.getGroupStatus()
+ local group = {}
+ for name in pairs(m.diagnosticGroups) do
+ group[name] = 'Fallback'
+ end
+ return group
+end
+
+---@param name string
+---@return string[]
+m.getGroups = util.cacheReturn(function (name)
+ local groups = {}
+ for groupName, nameMap in pairs(m.diagnosticGroups) do
+ if nameMap[name] then
+ groups[#groups+1] = groupName
+ end
+ end
+ table.sort(groups)
+ return groups
+end)
+
+---@return table<string, true>
+function m.getDiagAndErrNameMap()
+ if not m._diagAndErrNames then
+ local names = {}
+ for name in pairs(m.getDefaultSeverity()) do
+ names[name] = true
+ end
+ local path = package.searchpath('parser.compile', package.path)
+ if path then
+ local f = io.open(path)
+ if f then
+ for line in f:lines() do
+ local name = line:match([=[type%s*=%s*['"](%u[%u_]+%u)['"]]=])
+ if name then
+ local id = name:lower():gsub('_', '-')
+ names[id] = true
+ end
+ end
+ f:close()
+ end
+ end
+ table.sort(names)
+ m._diagAndErrNames = names
+ end
+ return m._diagAndErrNames
+end
+
+return m
diff --git a/script/provider/build-meta.lua b/script/provider/build-meta.lua
new file mode 100644
index 00000000..baabe39c
--- /dev/null
+++ b/script/provider/build-meta.lua
@@ -0,0 +1,155 @@
+local fs = require 'bee.filesystem'
+local config = require 'config'
+local util = require 'utility'
+local await = require 'await'
+local progress = require 'progress'
+local lang = require 'language'
+
+local m = {}
+
+---@class meta
+---@field root string
+---@field classes meta.class[]
+
+---@class meta.class
+---@field name string
+---@field comment string
+---@field location string
+---@field namespace string
+---@field baseClass string
+---@field attribute string
+---@field integerface string[]
+---@field fields meta.field[]
+---@field methods meta.method[]
+
+---@class meta.field
+---@field name string
+---@field typeName string
+---@field comment string
+---@field location string
+
+---@class meta.method
+---@field name string
+---@field comment string
+---@field location string
+---@field isStatic boolean
+---@field returnTypeName string
+---@field params {name: string, typeName: string}[]
+
+---@param ... string
+---@return string
+local function mergeString(...)
+ local buf = {}
+ for i = 1, select('#', ...) do
+ local str = select(i, ...)
+ if str ~= '' then
+ buf[#buf+1] = str
+ end
+ end
+ return table.concat(buf, '.')
+end
+
+local function addComments(lines, comment)
+ if comment == '' then
+ return
+ end
+ lines[#lines+1] = '--'
+ lines[#lines+1] = '--' .. comment:gsub('[\r\n]+$', ''):gsub('\n', '\n--')
+ lines[#lines+1] = '--'
+end
+
+---@param lines string[]
+---@param name string
+---@param method meta.method
+local function addMethod(lines, name, method)
+ if not method.name:match '^[%a_][%w_]*$' then
+ return
+ end
+ addComments(lines, method.comment)
+ lines[#lines+1] = ('---@source %s'):format(method.location:gsub('#', ':'))
+ local params = {}
+ for _, param in ipairs(method.params) do
+ lines[#lines+1] = ('---@param %s %s'):format(param.name, param.typeName)
+ params[#params+1] = param.name
+ end
+ if method.returnTypeName ~= ''
+ and method.returnTypeName ~= 'Void' then
+ lines[#lines+1] = ('---@return %s'):format(method.returnTypeName)
+ end
+ lines[#lines+1] = ('function %s%s%s(%s) end'):format(
+ name,
+ method.isStatic and ':' or '.',
+ method.name,
+ table.concat(params, ', ')
+ )
+ lines[#lines+1] = ''
+end
+
+---@param root string
+---@param class meta.class
+---@return string
+local function buildText(root, class)
+ local lines = {}
+
+ addComments(lines, class.comment)
+ lines[#lines+1] = ('---@source %s'):format(class.location:gsub('#', ':'))
+ if class.baseClass == '' then
+ lines[#lines+1] = ('---@class %s'):format(mergeString(class.namespace, class.name))
+ else
+ lines[#lines+1] = ('---@class %s: %s'):format(mergeString(class.namespace, class.name), class.baseClass)
+ end
+
+ for _, field in ipairs(class.fields) do
+ addComments(lines, field.comment)
+ lines[#lines+1] = ('---@source %s'):format(field.location:gsub('#', ':'))
+ lines[#lines+1] = ('---@field %s %s'):format(field.name, field.typeName)
+ end
+
+ lines[#lines+1] = ('---@source %s'):format(class.location:gsub('#', ':'))
+ local name = mergeString(root, class.namespace, class.name)
+ lines[#lines+1] = ('%s = {}'):format(name)
+ lines[#lines+1] = ''
+
+ for _, method in ipairs(class.methods) do
+ addMethod(lines, name, method)
+ end
+
+ return table.concat(lines, '\n')
+end
+
+local function buildRootText(api)
+ local lines = {}
+
+ lines[#lines+1] = ('---@class %s'):format(api.root)
+ lines[#lines+1] = ('%s = {}'):format(api.root)
+ lines[#lines+1] = ''
+ return table.concat(lines, '\n')
+end
+
+---@async
+---@param path string
+---@param api meta
+function m.build(path, api)
+
+ local files = util.multiTable(2, function ()
+ return { '---@meta' }
+ end)
+
+ files[api.root][#files[api.root]+1] = buildRootText(api)
+
+ local proc <close> = progress.create(nil, lang.script.WINDOW_PROCESSING_BUILD_META, 0.5)
+ for i, class in ipairs(api.classes) do
+ local space = class.namespace ~= '' and class.namespace or api.root
+ proc:setMessage(space)
+ proc:setPercentage(i / #api.classes * 100)
+ local text = buildText(api.root, class)
+ files[space][#files[space]+1] = text
+ await.delay()
+ end
+
+ for space, texts in pairs(files) do
+ util.saveFile(path .. '/' .. space .. '.lua', table.concat(texts, '\n\n'))
+ end
+end
+
+return m
diff --git a/script/provider/diagnostic.lua b/script/provider/diagnostic.lua
index 15b08d49..46ea600f 100644
--- a/script/provider/diagnostic.lua
+++ b/script/provider/diagnostic.lua
@@ -14,6 +14,10 @@ local loading = require 'workspace.loading'
local scope = require 'workspace.scope'
local time = require 'bee.time'
local ltable = require 'linked-table'
+local furi = require 'file-uri'
+local json = require 'json'
+local fw = require 'filewatch'
+local vm = require 'vm.vm'
---@class diagnosticProvider
local m = {}
@@ -29,6 +33,9 @@ end
local function buildSyntaxError(uri, err)
local text = files.getText(uri)
+ if not text then
+ return
+ end
local message = lang.script('PARSER_' .. err.type, err.info)
if err.version then
@@ -80,10 +87,14 @@ local function buildDiagnostic(uri, diag)
relatedInformation = {}
for _, rel in ipairs(diag.related) do
local rtext = files.getText(rel.uri)
+ if not rtext then
+ goto CONTINUE
+ end
relatedInformation[#relatedInformation+1] = {
message = rel.message or rtext:sub(rel.start, rel.finish),
location = converter.location(rel.uri, converter.packRange(rel.uri, rel.start, rel.finish))
}
+ ::CONTINUE::
end
end
@@ -136,9 +147,9 @@ local function mergeDiags(a, b, c)
end
-- enable `push`, disable `clear`
-function m.clear(uri)
+function m.clear(uri, force)
await.close('diag:' .. uri)
- if m.cache[uri] == nil then
+ if m.cache[uri] == nil and not force then
return
end
m.cache[uri] = nil
@@ -149,14 +160,27 @@ function m.clear(uri)
log.info('clearDiagnostics', uri)
end
--- enable `push` and `send`
-function m.clearCache(uri)
- m.cache[uri] = false
+function m.clearCacheExcept(uris)
+ local excepts = {}
+ for _, uri in ipairs(uris) do
+ excepts[uri] = true
+ end
+ for uri in pairs(m.cache) do
+ if not excepts[uri] then
+ m.cache[uri] = false
+ end
+ end
end
-function m.clearAll()
- for luri in pairs(m.cache) do
- m.clear(luri)
+function m.clearAll(force)
+ if force then
+ for luri in files.eachFile() do
+ m.clear(luri, force)
+ end
+ else
+ for luri in pairs(m.cache) do
+ m.clear(luri)
+ end
end
end
@@ -168,9 +192,11 @@ function m.syntaxErrors(uri, ast)
local results = {}
pcall(function ()
- local disables = config.get(uri, 'Lua.diagnostics.disable')
+ local disables = util.arrayToHash(config.get(uri, 'Lua.diagnostics.disable'))
for _, err in ipairs(ast.errs) do
- if not disables[err.type:lower():gsub('_', '-')] then
+ local id = err.type:lower():gsub('_', '-')
+ if not disables[id]
+ and not vm.isDiagDisabledAt(uri, err.start, id, true) then
results[#results+1] = buildSyntaxError(uri, err)
end
end
@@ -193,30 +219,45 @@ local function copyDiagsWithoutSyntax(diags)
end
---@async
-function m.doDiagnostic(uri, isScopeDiag)
+---@param uri uri
+---@return boolean
+local function isValid(uri)
if not config.get(uri, 'Lua.diagnostics.enable') then
- return
+ return false
end
if files.isLibrary(uri, true) then
local status = config.get(uri, 'Lua.diagnostics.libraryFiles')
if status == 'Disable' then
- return
+ return false
elseif status == 'Opened' then
if not files.isOpen(uri) then
- return
+ return false
end
end
end
if ws.isIgnored(uri) then
local status = config.get(uri, 'Lua.diagnostics.ignoredFiles')
if status == 'Disable' then
- return
+ return false
elseif status == 'Opened' then
if not files.isOpen(uri) then
- return
+ return false
end
end
end
+ local scheme = furi.split(uri)
+ local disableScheme = config.get(uri, 'Lua.diagnostics.disableScheme')
+ if util.arrayHas(disableScheme, scheme) then
+ return false
+ end
+ return true
+end
+
+---@async
+function m.doDiagnostic(uri, isScopeDiag)
+ if not isValid(uri) then
+ return
+ end
await.delay()
@@ -267,7 +308,7 @@ function m.doDiagnostic(uri, isScopeDiag)
xpcall(core, log.error, uri, isScopeDiag, function (result)
diags[#diags+1] = buildDiagnostic(uri, result)
- if not isScopeDiag and time.time() - lastPushClock >= 200 then
+ if not isScopeDiag and time.time() - lastPushClock >= 500 then
lastPushClock = time.time()
pushResult()
end
@@ -287,20 +328,78 @@ function m.doDiagnostic(uri, isScopeDiag)
pushResult()
end
+---@param uri uri
+function m.resendDiagnostic(uri)
+ local full = m.cache[uri]
+ if not full then
+ return
+ end
+
+ local version = files.getVersion(uri)
+
+ proto.notify('textDocument/publishDiagnostics', {
+ uri = uri,
+ version = version,
+ diagnostics = full,
+ })
+ log.debug('publishDiagnostics', uri, #full)
+end
+
+---@async
+---@return table|nil result
+---@return boolean? unchanged
+function m.pullDiagnostic(uri, isScopeDiag)
+ if not isValid(uri) then
+ return nil, util.equal(m.cache[uri], nil)
+ end
+
+ await.delay()
+
+ local state = files.getState(uri)
+ if not state then
+ return nil, util.equal(m.cache[uri], nil)
+ end
+
+ local prog <close> = progress.create(uri, lang.script.WINDOW_DIAGNOSING, 0.5)
+ prog:setMessage(ws.getRelativePath(uri))
+
+ local syntax = m.syntaxErrors(uri, state)
+ local diags = {}
+
+ xpcall(core, log.error, uri, isScopeDiag, function (result)
+ diags[#diags+1] = buildDiagnostic(uri, result)
+ end)
+
+ local full = mergeDiags(syntax, diags)
+
+ if util.equal(m.cache[uri], full) then
+ return full, true
+ end
+
+ m.cache[uri] = full
+
+ return full
+end
+
+---@param uri uri
function m.refresh(uri)
if not ws.isReady(uri) then
return
end
+
+ await.close('diag:' .. uri)
+ ---@async
+ await.call(function ()
+ await.setID('diag:' .. uri)
+ await.sleep(0.1)
+ xpcall(m.doDiagnostic, log.error, uri)
+ end)
+
local scp = scope.getScope(uri)
local scopeID = 'diagnosticsScope:' .. scp:getName()
- await.close('diag:' .. uri)
await.close(scopeID)
- await.call(function () ---@async
- if uri then
- await.setID('diag:' .. uri)
- await.sleep(0.1)
- xpcall(m.doDiagnostic, log.error, uri)
- end
+ ---@async
+ await.call(function ()
local delay = config.get(uri, 'Lua.diagnostics.workspaceDelay') / 1000
if delay < 0 then
return
@@ -359,7 +458,7 @@ local function askForDisable(uri)
end
---@async
-function m.awaitDiagnosticsScope(suri)
+function m.awaitDiagnosticsScope(suri, callback)
local scp = scope.getScope(suri)
while loading.count() > 0 do
await.sleep(1.0)
@@ -393,7 +492,7 @@ function m.awaitDiagnosticsScope(suri)
i = i + 1
bar:setMessage(('%d/%d'):format(i, #uris))
bar:setPercentage(i / #uris * 100)
- xpcall(m.doDiagnostic, log.error, uri, true)
+ callback(uri)
await.delay()
if cancelled then
log.info('Break workspace diagnostics')
@@ -416,13 +515,66 @@ function m.diagnosticsScope(uri, force)
local id = 'diagnosticsScope:' .. scp:getName()
await.close(id)
await.call(function () ---@async
- m.awaitDiagnosticsScope(uri)
+ await.sleep(0.0)
+ m.awaitDiagnosticsScope(uri, function (fileUri)
+ xpcall(m.doDiagnostic, log.error, fileUri, true)
+ end)
end, id)
end
+---@async
+function m.pullDiagnosticScope(callback)
+ local processing = 0
+
+ for _, scp in ipairs(scope.folders) do
+ if ws.isReady(scp.uri)
+ and config.get(scp.uri, 'Lua.diagnostics.enable') then
+ local id = 'diagnosticsScope:' .. scp:getName()
+ await.close(id)
+ await.call(function () ---@async
+ processing = processing + 1
+ local _ <close> = util.defer(function ()
+ processing = processing - 1
+ end)
+
+ local delay = config.get(scp.uri, 'Lua.diagnostics.workspaceDelay') / 1000
+ if delay < 0 then
+ return
+ end
+ print(delay)
+ await.sleep(math.max(delay, 0.2))
+ print('start')
+
+ m.awaitDiagnosticsScope(scp.uri, function (fileUri)
+ local suc, result, unchanged = xpcall(m.pullDiagnostic, log.error, fileUri, true)
+ if suc then
+ callback {
+ uri = fileUri,
+ result = result,
+ unchanged = unchanged,
+ version = files.getVersion(fileUri),
+ }
+ end
+ end)
+ end, id)
+ end
+ end
+
+ -- sleep for ever
+ while true do
+ await.sleep(1.0)
+ end
+end
+
+function m.refreshClient()
+ log.debug('Refresh client diagnostics')
+ proto.request('workspace/diagnostic/refresh', json.null)
+end
+
ws.watch(function (ev, uri)
if ev == 'reload' then
m.diagnosticsScope(uri)
+ m.refreshClient()
end
end)
@@ -434,7 +586,7 @@ files.watch(function (ev, uri) ---@async
m.refresh(uri)
elseif ev == 'open' then
if ws.isReady(uri) then
- m.clearCache(uri)
+ m.resendDiagnostic(uri)
xpcall(m.doDiagnostic, log.error, uri)
end
elseif ev == 'close' then
@@ -446,9 +598,20 @@ files.watch(function (ev, uri) ---@async
end)
config.watch(function (uri, key, value, oldValue)
- if key:find 'Lua.diagnostics' then
+ if util.stringStartWith(key, 'Lua.diagnostics')
+ or util.stringStartWith(key, 'Lua.spell') then
if value ~= oldValue then
m.diagnosticsScope(uri)
+ m.refreshClient()
+ end
+ end
+end)
+
+fw.event(function (ev, path)
+ if util.stringEndWith(path, '.editorconfig') then
+ for _, scp in ipairs(ws.folders) do
+ m.diagnosticsScope(scp.uri)
+ m.refreshClient()
end
end
end)
diff --git a/script/provider/formatting.lua b/script/provider/formatting.lua
index 73b6608d..4ec5545a 100644
--- a/script/provider/formatting.lua
+++ b/script/provider/formatting.lua
@@ -19,7 +19,7 @@ local updateType = {
Deleted = 3,
}
-fw.event(function (ev, path)
+fw.event(function(ev, path)
if util.stringEndWith(path, '.editorconfig') then
for uri, fsPath in pairs(loadedUris) do
loadedUris[uri] = nil
@@ -30,17 +30,9 @@ fw.event(function (ev, path)
end
end
end
- for _, scp in ipairs(ws.folders) do
- diagnostics.diagnosticsScope(scp.uri)
- end
end
end)
-config.watch(function (uri, key, value)
- if key == "Lua.format.defaultConfig" then
- codeFormat.set_default_config(value)
- end
-end)
local m = {}
@@ -51,6 +43,7 @@ function m.updateConfig(uri)
if not m.loadedDefaultConfig then
m.loadedDefaultConfig = true
codeFormat.set_default_config(config.get(uri, 'Lua.format.defaultConfig'))
+ m.updateNonStandardSymbols(config.get(nil, 'Lua.runtime.nonstandardSymbol'))
end
local currentUri = uri
@@ -64,7 +57,7 @@ function m.updateConfig(uri)
local currentPath = furi.decode(currentUri)
local editorConfigFSPath = fs.path(currentPath) / '.editorconfig'
if fs.exists(editorConfigFSPath) then
- loadedUris[uri] = editorConfigFSPath
+ loadedUris[currentUri] = editorConfigFSPath
local status, err = codeFormat.update_config(updateType.Created, currentUri, editorConfigFSPath:string())
if not status and err then
log.error(err)
@@ -83,4 +76,30 @@ function m.updateConfig(uri)
end
end
+---@param symbols? string[]
+function m.updateNonStandardSymbols(symbols)
+ if symbols == nil then
+ return
+ end
+
+ local eqTokens = {}
+ for _, token in ipairs(symbols) do
+ if token:find("=") and token ~= "!=" then
+ table.insert(eqTokens, token)
+ end
+ end
+
+ if #eqTokens ~= 0 then
+ codeFormat.set_nonstandard_symbol('=', eqTokens)
+ end
+end
+
+config.watch(function(uri, key, value)
+ if key == "Lua.format.defaultConfig" then
+ codeFormat.set_default_config(value)
+ elseif key == "Lua.runtime.nonstandardSymbol" then
+ m.updateNonStandardSymbols(value)
+ end
+end)
+
return m
diff --git a/script/provider/markdown.lua b/script/provider/markdown.lua
index 6b7d24c8..50716073 100644
--- a/script/provider/markdown.lua
+++ b/script/provider/markdown.lua
@@ -10,7 +10,7 @@ function mt:__tostring()
end
---@param language string
----@param text string|markdown
+---@param text? string|markdown
function mt:add(language, text)
if not text then
return self
@@ -40,7 +40,16 @@ function mt:splitLine()
return self
end
-function mt:string()
+function mt:emptyLine()
+ self._cacheResult = nil
+ self[#self+1] = {
+ type = 'emptyline',
+ }
+ return self
+end
+
+---@return string
+function mt:string(nl)
if self._cacheResult then
return self._cacheResult
end
@@ -59,6 +68,11 @@ function mt:string()
lines[#lines+1] = ''
lines[#lines+1] = '---'
end
+ elseif obj.type == 'emptyline' then
+ if #lines > 0
+ and lines[#lines] ~= '' then
+ lines[#lines+1] = ''
+ end
elseif obj.type == 'markdown' then
concat(obj.markdown)
else
@@ -80,6 +94,10 @@ function mt:string()
if lines[#lines] ~= '' then
lines[#lines+1] = ''
end
+ elseif last == '---' then
+ if lines[#lines] ~= '' then
+ lines[#lines+1] = ''
+ end
end
end
lines[#lines+1] = obj.text
@@ -101,7 +119,7 @@ function mt:string()
end
end
- local result = table.concat(lines, '\n')
+ local result = table.concat(lines, nl or '\n')
self._cacheResult = result
return result
end
diff --git a/script/provider/provider.lua b/script/provider/provider.lua
index 3d012757..018db0c3 100644
--- a/script/provider/provider.lua
+++ b/script/provider/provider.lua
@@ -21,6 +21,8 @@ local furi = require 'file-uri'
local inspect = require 'inspect'
local markdown = require 'provider.markdown'
local guide = require 'parser.guide'
+local fs = require 'bee.filesystem'
+local jumpSource = require 'core.jump-source'
---@async
local function updateConfig(uri)
@@ -28,7 +30,7 @@ local function updateConfig(uri)
local specified = cfgLoader.loadLocalConfig(uri, CONFIGPATH)
if specified then
log.info('Load config from specified', CONFIGPATH)
- log.debug(inspect(specified))
+ log.info(inspect(specified))
-- watch directory
filewatch.watch(workspace.getAbsolutePath(uri, CONFIGPATH):gsub('[^/\\]+$', ''))
config.update(scope.override, specified)
@@ -38,14 +40,14 @@ local function updateConfig(uri)
local clientConfig = cfgLoader.loadClientConfig(folder.uri)
if clientConfig then
log.info('Load config from client', folder.uri)
- log.debug(inspect(clientConfig))
+ log.info(inspect(clientConfig))
end
local rc = cfgLoader.loadRCConfig(folder.uri, '.luarc.json')
or cfgLoader.loadRCConfig(folder.uri, '.luarc.jsonc')
if rc then
log.info('Load config from .luarc.json/.luarc.jsonc', folder.uri)
- log.debug(inspect(rc))
+ log.info(inspect(rc))
end
config.update(folder, clientConfig, rc)
@@ -53,7 +55,7 @@ local function updateConfig(uri)
local global = cfgLoader.loadClientConfig()
log.info('Load config from client', 'fallback')
- log.debug(inspect(global))
+ log.info(inspect(global))
config.update(scope.fallback, global)
end
@@ -149,7 +151,6 @@ m.register 'initialized'{
})
end
client.setReady()
- library.init()
workspace.init()
return true
end
@@ -234,11 +235,35 @@ m.register 'workspace/didRenameFiles' {
end
}
+m.register 'workspace/didChangeWorkspaceFolders' {
+ capability = {
+ workspace = {
+ workspaceFolders = {
+ supported = true,
+ changeNotifications = true,
+ },
+ },
+ },
+ ---@async
+ function (params)
+ log.debug('workspace/didChangeWorkspaceFolders', inspect(params))
+ for _, folder in ipairs(params.event.added) do
+ workspace.create(folder.uri)
+ updateConfig()
+ workspace.reload(scope.getScope(folder.uri))
+ end
+ for _, folder in ipairs(params.event.removed) do
+ workspace.remove(folder.uri)
+ end
+ end
+}
+
m.register 'textDocument/didOpen' {
function (params)
- local doc = params.textDocument
- local scheme = furi.split(doc.uri)
- if scheme ~= 'file' then
+ local doc = params.textDocument
+ local scheme = furi.split(doc.uri)
+ local supports = config.get(doc.uri, 'Lua.workspace.supportScheme')
+ if not util.arrayHas(supports, scheme) then
return
end
local uri = files.getRealUri(doc.uri)
@@ -264,15 +289,21 @@ m.register 'textDocument/didClose' {
}
m.register 'textDocument/didChange' {
+ ---@async
function (params)
- local doc = params.textDocument
- local scheme = furi.split(doc.uri)
- if scheme ~= 'file' then
+ local doc = params.textDocument
+ local scheme = furi.split(doc.uri)
+ local supports = config.get(doc.uri, 'Lua.workspace.supportScheme')
+ if not util.arrayHas(supports, scheme) then
return
end
local changes = params.contentChanges
local uri = files.getRealUri(doc.uri)
- local text = files.getOriginText(uri) or ''
+ local text = files.getOriginText(uri)
+ if not text then
+ files.setText(uri, pub.awaitTask('loadFile', furi.decode(uri)), false)
+ return
+ end
local rows = files.getCachedRows(uri)
text, rows = tm(text, rows, changes)
files.setText(uri, text, true, function (file)
@@ -310,7 +341,7 @@ m.register 'textDocument/hover' {
end
local pos = converter.unpackPosition(uri, params.position)
local hover, source = core.byUri(uri, pos)
- if not hover then
+ if not hover or not source then
return nil
end
return {
@@ -346,18 +377,16 @@ m.register 'textDocument/definition' {
for i, info in ipairs(result) do
local targetUri = info.uri
if targetUri then
- if files.exists(targetUri) then
- if client.getAbility 'textDocument.definition.linkSupport' then
- response[i] = converter.locationLink(targetUri
- , converter.packRange(targetUri, info.target.start, info.target.finish)
- , converter.packRange(targetUri, info.target.start, info.target.finish)
- , converter.packRange(uri, info.source.start, info.source.finish)
- )
- else
- response[i] = converter.location(targetUri
- , converter.packRange(targetUri, info.target.start, info.target.finish)
- )
- end
+ if client.getAbility 'textDocument.definition.linkSupport' then
+ response[i] = converter.locationLink(targetUri
+ , converter.packRange(targetUri, info.target.start, info.target.finish)
+ , converter.packRange(targetUri, info.target.start, info.target.finish)
+ , converter.packRange(uri, info.source.start, info.source.finish)
+ )
+ else
+ response[i] = converter.location(targetUri
+ , converter.packRange(targetUri, info.target.start, info.target.finish)
+ )
end
end
end
@@ -388,18 +417,16 @@ m.register 'textDocument/typeDefinition' {
for i, info in ipairs(result) do
local targetUri = info.uri
if targetUri then
- if files.exists(targetUri) then
- if client.getAbility 'textDocument.typeDefinition.linkSupport' then
- response[i] = converter.locationLink(targetUri
- , converter.packRange(targetUri, info.target.start, info.target.finish)
- , converter.packRange(targetUri, info.target.start, info.target.finish)
- , converter.packRange(uri, info.source.start, info.source.finish)
- )
- else
- response[i] = converter.location(targetUri
- , converter.packRange(targetUri, info.target.start, info.target.finish)
- )
- end
+ if client.getAbility 'textDocument.typeDefinition.linkSupport' then
+ response[i] = converter.locationLink(targetUri
+ , converter.packRange(targetUri, info.target.start, info.target.finish)
+ , converter.packRange(targetUri, info.target.start, info.target.finish)
+ , converter.packRange(uri, info.source.start, info.source.finish)
+ )
+ else
+ response[i] = converter.location(targetUri
+ , converter.packRange(targetUri, info.target.start, info.target.finish)
+ )
end
end
end
@@ -902,29 +929,35 @@ local function toArray(map)
return array
end
-m.register 'textDocument/semanticTokens/full' {
- capability = {
- semanticTokensProvider = {
- legend = {
- tokenTypes = toArray(define.TokenTypes),
- tokenModifiers = toArray(define.TokenModifiers),
- },
- full = true,
- },
- },
- ---@async
- function (params)
- log.debug('textDocument/semanticTokens/full')
- local uri = files.getRealUri(params.textDocument.uri)
- workspace.awaitReady(uri)
- local _ <close> = progress.create(uri, lang.script.WINDOW_PROCESSING_SEMANTIC_FULL, 0.5)
- local core = require 'core.semantic-tokens'
- local results = core(uri, 0, math.huge)
- return {
- data = results
- }
+client.event(function (ev)
+ if ev == 'init' then
+ if not client.isVSCode() then
+ m.register 'textDocument/semanticTokens/full' {
+ capability = {
+ semanticTokensProvider = {
+ legend = {
+ tokenTypes = toArray(define.TokenTypes),
+ tokenModifiers = toArray(define.TokenModifiers),
+ },
+ full = true,
+ },
+ },
+ ---@async
+ function (params)
+ log.debug('textDocument/semanticTokens/full')
+ local uri = files.getRealUri(params.textDocument.uri)
+ workspace.awaitReady(uri)
+ local _ <close> = progress.create(uri, lang.script.WINDOW_PROCESSING_SEMANTIC_FULL, 0.5)
+ local core = require 'core.semantic-tokens'
+ local results = core(uri, 0, math.huge)
+ return {
+ data = results
+ }
+ end
+ }
+ end
end
-}
+end)
m.register 'textDocument/semanticTokens/range' {
capability = {
@@ -988,6 +1021,40 @@ m.register 'textDocument/foldingRange' {
end
}
+m.register 'textDocument/documentColor' {
+ capability = {
+ colorProvider = true
+ },
+ ---@async
+ function (params)
+ local color = require 'core.color'
+ local uri = files.getRealUri(params.textDocument.uri)
+ workspace.awaitReady(uri)
+ if not files.exists(uri) then
+ return nil
+ end
+ local colors = color.colors(uri)
+ if not colors then
+ return nil
+ end
+ local results = {}
+ for _, colorValue in ipairs(colors) do
+ results[#results+1] = {
+ range = converter.packRange(uri, colorValue.start, colorValue.finish),
+ color = colorValue.color
+ }
+ end
+ return results
+ end
+}
+
+m.register 'textDocument/colorPresentation' {
+ function (params)
+ local color = (require 'core.color').colorToText(params.color)
+ return {{label = color}}
+ end
+}
+
m.register 'window/workDoneProgress/cancel' {
function (params)
log.debug('close proto(cancel):', params.token)
@@ -1001,6 +1068,7 @@ m.register '$/status/click' {
local titleDiagnostic = lang.script.WINDOW_LUA_STATUS_DIAGNOSIS_TITLE
local result = client.awaitRequestMessage('Info', lang.script.WINDOW_LUA_STATUS_DIAGNOSIS_MSG, {
titleDiagnostic,
+ DEVELOP and 'Restart Server',
})
if not result then
return
@@ -1010,6 +1078,10 @@ m.register '$/status/click' {
for _, scp in ipairs(workspace.folders) do
diagnostic.diagnosticsScope(scp.uri, true)
end
+ elseif result == 'Restart Server' then
+ local diag = require 'provider.diagnostic'
+ diag.clearAll(true)
+ os.exit(0, true)
end
end
}
@@ -1210,6 +1282,123 @@ m.register 'inlayHint/resolve' {
end
}
+m.register 'textDocument/diagnostic' {
+ preview = true,
+ capability = {
+ diagnosticProvider = {
+ identifier = 'identifier',
+ interFileDependencies = true,
+ workspaceDiagnostics = false,
+ }
+ },
+ ---@async
+ function (params)
+ local uri = files.getRealUri(params.textDocument.uri)
+ workspace.awaitReady(uri)
+ local core = require 'provider.diagnostic'
+ -- TODO: do some trick
+ core.doDiagnostic(uri)
+
+ return {
+ kind = 'unchanged',
+ resultId = uri,
+ }
+
+ --if not params.previousResultId then
+ -- core.clearCache(uri)
+ --end
+ --local results, unchanged = core.pullDiagnostic(uri, false)
+ --if unchanged then
+ -- return {
+ -- kind = 'unchanged',
+ -- resultId = uri,
+ -- }
+ --else
+ -- return {
+ -- kind = 'full',
+ -- resultId = uri,
+ -- items = results or {},
+ -- }
+ --end
+ end
+}
+
+m.register 'workspace/diagnostic' {
+ --preview = true,
+ --capability = {
+ -- diagnosticProvider = {
+ -- workspaceDiagnostics = false,
+ -- }
+ --},
+ ---@async
+ function (params)
+ local core = require 'provider.diagnostic'
+ local excepts = {}
+ for _, id in ipairs(params.previousResultIds) do
+ excepts[#excepts+1] = id.value
+ end
+ core.clearCacheExcept(excepts)
+ local function convertItem(result)
+ if result.unchanged then
+ return {
+ kind = 'unchanged',
+ resultId = result.uri,
+ uri = result.uri,
+ version = result.version,
+ }
+ else
+ return {
+ kind = 'full',
+ resultId = result.uri,
+ items = result.result or {},
+ uri = result.uri,
+ version = result.version,
+ }
+ end
+ end
+ core.pullDiagnosticScope(function (result)
+ proto.notify('$/progress', {
+ token = params.partialResultToken,
+ value = {
+ items = {
+ convertItem(result)
+ }
+ }
+ })
+ end)
+ return { items = {} }
+ end
+}
+
+m.register '$/api/report' {
+ ---@async
+ function (params)
+ local buildMeta = require 'provider.build-meta'
+ local SDBMHash = require 'SDBMHash'
+ await.close 'api/report'
+ await.setID 'api/report'
+ local name = params.name or 'default'
+ local uri = workspace.getFirstScope().uri
+ local hash = uri and ('%08x'):format(SDBMHash():hash(uri))
+ local encoding = config.get(nil, 'Lua.runtime.fileEncoding')
+ local nameBuf = {}
+ nameBuf[#nameBuf+1] = name
+ nameBuf[#nameBuf+1] = hash
+ nameBuf[#nameBuf+1] = encoding
+ local fileDir = METAPATH .. '/' .. table.concat(nameBuf, ' ')
+ fs.create_directories(fs.path(fileDir))
+ buildMeta.build(fileDir, params)
+ client.setConfig {
+ {
+ key = 'Lua.workspace.library',
+ action = 'add',
+ value = fileDir,
+ uri = uri,
+ }
+ }
+ end
+}
+
local function refreshStatusBar()
local valid = config.get(nil, 'Lua.window.statusBar')
for _, scp in ipairs(workspace.folders) do
diff --git a/script/provider/spell.lua b/script/provider/spell.lua
new file mode 100644
index 00000000..6647bbad
--- /dev/null
+++ b/script/provider/spell.lua
@@ -0,0 +1,53 @@
+local suc, codeFormat = pcall(require, 'code_format')
+if not suc then
+ return
+end
+
+local fs = require 'bee.filesystem'
+local config = require 'config'
+local diagnostics = require 'provider.diagnostic'
+local pformatting = require 'provider.formatting'
+local util = require 'utility'
+
+local m = {}
+
+function m.loadDictionaryFromFile(filePath)
+ return codeFormat.spell_load_dictionary_from_path(filePath)
+end
+
+function m.loadDictionaryFromBuffer(buffer)
+ return codeFormat.spell_load_dictionary_from_buffer(buffer)
+end
+
+function m.addWord(word)
+ return codeFormat.spell_load_dictionary_from_buffer(word)
+end
+
+function m.spellCheck(uri, text)
+ if not m._dictionaryLoaded then
+ m.initDictionary()
+ m._dictionaryLoaded = true
+ end
+
+ local tempDict = config.get(uri, 'Lua.spell.dict')
+
+ return codeFormat.spell_analysis(uri, text, tempDict)
+end
+
+function m.getSpellSuggest(word)
+ local status, result = codeFormat.spell_suggest(word)
+ if status then
+ return result
+ end
+end
+
+function m.initDictionary()
+ local basicDictionary = fs.path(METAPATH) / "spell/dictionary.txt"
+ local luaDictionary = fs.path(METAPATH) / "spell/lua_dict.txt"
+
+ m.loadDictionaryFromFile(basicDictionary:string())
+ m.loadDictionaryFromFile(luaDictionary:string())
+ pformatting.updateNonStandardSymbols(config.get(nil, "Lua.runtime.nonstandardSymbol"))
+end
+
+return m
diff --git a/script/pub/pub.lua b/script/pub/pub.lua
index 47591ee6..1e9b6c8f 100644
--- a/script/pub/pub.lua
+++ b/script/pub/pub.lua
@@ -136,8 +136,6 @@ function m.task(name, params, callback)
end
--- 接收反馈
---- 返回接收到的反馈数量
----@return integer
function m.recieve(block)
if block then
local id, name, result = waiter:bpop()
diff --git a/script/service/service.lua b/script/service/service.lua
index 26790c63..07612d7b 100644
--- a/script/service/service.lua
+++ b/script/service/service.lua
@@ -235,7 +235,7 @@ end
function m.testVersion()
local stack = debug.setcstacklimit(200)
debug.setcstacklimit(stack + 1)
- if debug.setcstacklimit(stack) == stack + 1 then
+ if type(stack) == 'number' and debug.setcstacklimit(stack) == stack + 1 then
proto.notify('window/showMessage', {
type = 2,
message = 'It seems to be running in Lua 5.4.0 or Lua 5.4.1 . Please upgrade to Lua 5.4.2 or above. Otherwise, it may encounter weird "C stack overflow", resulting in failure to work properly',
diff --git a/script/service/telemetry.lua b/script/service/telemetry.lua
index 2e52def2..211dff0e 100644
--- a/script/service/telemetry.lua
+++ b/script/service/telemetry.lua
@@ -90,6 +90,7 @@ local function pushErrorLog(link)
))
end
+---@type boolean?
local isValid = false
timer.wait(5, function ()
diff --git a/script/utility.lua b/script/utility.lua
index 47b0c8d8..7f9d559b 100644
--- a/script/utility.lua
+++ b/script/utility.lua
@@ -8,6 +8,7 @@ local ipairs = ipairs
local next = next
local rawset = rawset
local move = table.move
+local tableRemove = table.remove
local setmetatable = debug.setmetatable
local mathType = math.type
local mathCeil = math.ceil
@@ -157,8 +158,10 @@ function m.dump(tbl, option)
local tp = type(value)
local format = option['format'] and option['format'][key]
if format then
- lines[#lines+1] = ('%s%s%s,'):format(TAB[deep+1], keyWord, format(value, unpack, deep+1, stack))
- elseif tp == 'table' then
+ value = format(value, unpack, deep+1, stack)
+ tp = type(value)
+ end
+ if tp == 'table' then
if mark[value] and mark[value] > 0 then
lines[#lines+1] = ('%s%s%s,'):format(TAB[deep+1], keyWord, option['loop'] or '"<Loop>"')
elseif deep >= (option['deep'] or mathHuge) then
@@ -272,7 +275,7 @@ local function sortTable(tbl)
end
--- 创建一个有序表
----@param tbl table {optional = 'self'}
+---@param tbl? table
---@return table
function m.container(tbl)
return sortTable(tbl)
@@ -287,6 +290,9 @@ function m.loadFile(path, keepBom)
end
local text = f:read 'a'
f:close()
+ if not text then
+ return nil
+ end
if not keepBom then
if text:sub(1, 3) == '\xEF\xBB\xBF' then
return text:sub(4)
@@ -521,6 +527,14 @@ function m.revertTable(t)
return t
end
+function m.revertMap(t)
+ local nt = {}
+ for k, v in pairs(t) do
+ nt[v] = k
+ end
+ return nt
+end
+
function m.randomSortTable(t, max)
local len = #t
if len <= 1 then
@@ -567,7 +581,7 @@ end
---遍历文本的每一行
---@param text string
---@param keepNL? boolean # 保留换行符
----@return fun(text:string):string, integer
+---@return fun():string, integer
function m.eachLine(text, keepNL)
local offset = 1
local lineCount = 0
@@ -649,9 +663,6 @@ function m.trim(str, mode)
end
function m.expandPath(path)
- if type(path) ~= 'string' then
- return nil
- end
if path:sub(1, 1) == '~' then
local home = getenv('HOME')
if not home then -- has to be Windows
@@ -791,8 +802,60 @@ function m.multiTable(count, default)
return current
end
+---@param t table
+---@param sorter boolean|function
+function m.getTableKeys(t, sorter)
+ local keys = {}
+ for k in pairs(t) do
+ keys[#keys+1] = k
+ end
+ if sorter == true then
+ tableSort(keys)
+ elseif type(sorter) == 'function' then
+ tableSort(keys, sorter)
+ end
+ return keys
+end
+
+function m.arrayHas(array, value)
+ for i = 1, #array do
+ if array[i] == value then
+ return true
+ end
+ end
+ return false
+end
+
+function m.arrayInsert(array, value)
+ if not m.arrayHas(array, value) then
+ array[#array+1] = value
+ end
+end
+
+function m.arrayRemove(array, value)
+ for i = 1, #array do
+ if array[i] == value then
+ tableRemove(array, i)
+ return
+ end
+ end
+end
+
m.MODE_K = { __mode = 'k' }
m.MODE_V = { __mode = 'v' }
m.MODE_KV = { __mode = 'kv' }
+---@generic T: fun(param: any):any
+---@param func T
+---@return T
+function m.cacheReturn(func)
+ local cache = {}
+ return function (param)
+ if cache[param] == nil then
+ cache[param] = func(param)
+ end
+ return cache[param]
+ end
+end
+
return m
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 75620d19..60c7243e 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -6,10 +6,149 @@ local files = require 'files'
---@class vm
local vm = require 'vm.vm'
+local LOCK = {}
+
---@class parser.object
---@field _compiledNodes boolean
---@field _node vm.node
---@field _globalBase table
+---@field cindex integer
+---@field func parser.object
+---@field operators? parser.object[]
+
+-- 该函数有副作用,会给source绑定node!
+---@param source parser.object
+---@return boolean
+local function bindDocs(source)
+ local docs = source.bindDocs
+ if not docs then
+ return false
+ end
+ for i = #docs, 1, -1 do
+ local doc = docs[i]
+ if doc.type == 'doc.type' then
+ vm.setNode(source, vm.compileNode(doc))
+ return true
+ end
+ if doc.type == 'doc.class' then
+ vm.setNode(source, vm.compileNode(doc))
+ return true
+ end
+ if doc.type == 'doc.param' then
+ local node = vm.compileNode(doc)
+ if doc.optional then
+ node:addOptional()
+ end
+ vm.setNode(source, node)
+ return true
+ end
+ if doc.type == 'doc.module' then
+ local name = doc.module
+ if not name then
+ return true
+ end
+ local uri = rpath.findUrisByRequireName(guide.getUri(source), name)[1]
+ if not uri then
+ return true
+ end
+ local state = files.getState(uri)
+ local ast = state and state.ast
+ if not ast then
+ return true
+ end
+ vm.setNode(source, vm.compileNode(ast))
+ return true
+ end
+ if doc.type == 'doc.overload' then
+ vm.setNode(source, vm.compileNode(doc))
+ end
+ end
+ return false
+end
+
+---@param source parser.object
+---@param key any
+---@param ref boolean
+---@param pushResult fun(res: parser.object, markDoc?: boolean)
+local function searchFieldByLocalID(source, key, ref, pushResult)
+ local fields
+ if key then
+ fields = vm.getLocalSourcesSets(source, key)
+ if ref then
+ local gets = vm.getLocalSourcesGets(source, key)
+ if gets then
+ fields = fields or {}
+ for _, src in ipairs(gets) do
+ fields[#fields+1] = src
+ end
+ end
+ end
+ else
+ fields = vm.getLocalFields(source, false)
+ end
+ if not fields then
+ return
+ end
+ local hasMarkDoc = {}
+ for _, src in ipairs(fields) do
+ if src.bindDocs then
+ if bindDocs(src) then
+ local skey = guide.getKeyName(src)
+ if skey then
+ hasMarkDoc[skey] = true
+ end
+ pushResult(src, true)
+ end
+ end
+ end
+ for _, src in ipairs(fields) do
+ local skey = guide.getKeyName(src)
+ if not hasMarkDoc[skey] then
+ pushResult(src)
+ end
+ end
+end
+
+---@param suri uri
+---@param source parser.object
+---@param key any
+---@param ref boolean
+---@param pushResult fun(res: parser.object, markDoc?: boolean)
+local function searchFieldByGlobalID(suri, source, key, ref, pushResult)
+ local node = source._globalNode
+ if not node then
+ return
+ end
+ if node.cate == 'variable' then
+ if key then
+ if type(key) ~= 'string' then
+ return
+ end
+ local global = vm.getGlobal('variable', node.name, key)
+ if global then
+ for _, set in ipairs(global:getSets(suri)) do
+ pushResult(set)
+ end
+ for _, get in ipairs(global:getGets(suri)) do
+ pushResult(get)
+ end
+ end
+ else
+ local globals = vm.getGlobalFields('variable', node.name)
+ for _, global in ipairs(globals) do
+ for _, set in ipairs(global:getSets(suri)) do
+ pushResult(set)
+ end
+ for _, get in ipairs(global:getGets(suri)) do
+ pushResult(get)
+ end
+ end
+ end
+ end
+ if node.cate == 'type' then
+ vm.getClassFields(suri, node, key, ref, pushResult)
+ end
+end
local searchFieldSwitch = util.switch()
: case 'table'
@@ -47,6 +186,7 @@ local searchFieldSwitch = util.switch()
end
end)
: case 'string'
+ : case 'doc.type.string'
: call(function (suri, source, key, ref, pushResult)
-- change to `string: stringlib` ?
local stringlib = vm.getGlobal('type', 'stringlib')
@@ -54,23 +194,6 @@ local searchFieldSwitch = util.switch()
vm.getClassFields(suri, stringlib, key, ref, pushResult)
end
end)
- : case 'local'
- : case 'self'
- : call(function (suri, node, key, ref, pushResult)
- local fields
- if key then
- fields = vm.getLocalSources(node, key)
- else
- fields = vm.getLocalFields(node)
- end
- if fields then
- for _, src in ipairs(fields) do
- if ref or guide.isSet(src) then
- pushResult(src)
- end
- end
- end
- end)
: case 'doc.type.array'
: call(function (suri, source, key, ref, pushResult)
if type(key) == 'number' then
@@ -78,8 +201,13 @@ local searchFieldSwitch = util.switch()
or not math.tointeger(key) then
return
end
+ pushResult(source.node)
+ end
+ if type(key) == 'table' then
+ if vm.isSubType(suri, key, 'integer') then
+ pushResult(source.node)
+ end
end
- pushResult(source.node)
end)
: case 'doc.type.table'
: call(function (suri, source, key, ref, pushResult)
@@ -144,42 +272,15 @@ local searchFieldSwitch = util.switch()
end
end)
: default(function (suri, source, key, ref, pushResult)
- local node = source._globalNode
- if not node then
- return
- end
- if node.cate == 'variable' then
- if key then
- if type(key) ~= 'string' then
- return
- end
- local global = vm.getGlobal('variable', node.name, key)
- if global then
- for _, set in ipairs(global:getSets(suri)) do
- pushResult(set)
- end
- for _, get in ipairs(global:getGets(suri)) do
- pushResult(get)
- end
- end
- else
- local globals = vm.getGlobalFields('variable', node.name)
- for _, global in ipairs(globals) do
- for _, set in ipairs(global:getSets(suri)) do
- pushResult(set)
- end
- for _, get in ipairs(global:getGets(suri)) do
- pushResult(get)
- end
- end
- end
- end
- if node.cate == 'type' then
- vm.getClassFields(suri, node, key, ref, pushResult)
- end
+ searchFieldByLocalID(source, key, ref, pushResult)
+ searchFieldByGlobalID(suri, source, key, ref, pushResult)
end)
-
+---@param suri uri
+---@param object vm.global
+---@param key string|vm.global
+---@param ref boolean
+---@param pushResult fun(field: vm.object, isMark?: boolean)
function vm.getClassFields(suri, object, key, ref, pushResult)
local mark = {}
@@ -194,6 +295,14 @@ function vm.getClassFields(suri, object, key, ref, pushResult)
if set.type == 'doc.class' then
-- check ---@field
local hasFounded = {}
+
+ local function copyToSearched()
+ for fieldKey in pairs(hasFounded) do
+ searchedFields[fieldKey] = true
+ hasFounded[fieldKey] = nil
+ end
+ end
+
for _, field in ipairs(set.fields) do
local fieldKey = guide.getKeyName(field)
if fieldKey then
@@ -201,12 +310,11 @@ function vm.getClassFields(suri, object, key, ref, pushResult)
if key == nil
or fieldKey == key then
if not searchedFields[fieldKey] then
- pushResult(field)
+ pushResult(field, true)
hasFounded[fieldKey] = true
end
end
- end
- if not hasFounded[fieldKey] then
+ elseif key and not hasFounded[key] then
local keyType = type(key)
if keyType == 'table' then
-- ---@field [integer] boolean -> class[integer]
@@ -214,7 +322,7 @@ function vm.getClassFields(suri, object, key, ref, pushResult)
if vm.isSubType(suri, key.name, fieldNode) then
local nkey = '|' .. key.name
if not searchedFields[nkey] then
- pushResult(field)
+ pushResult(field, true)
hasFounded[nkey] = true
end
end
@@ -230,13 +338,13 @@ function vm.getClassFields(suri, object, key, ref, pushResult)
or keyType == 'string' then
typeName = keyType
end
- if typeName then
+ if typeName and field.field.type ~= 'doc.field.name' then
-- ---@field [integer] boolean -> class[1]
local fieldNode = vm.compileNode(field.field)
if vm.isSubType(suri, typeName, fieldNode) then
local nkey = '|' .. typeName
if not searchedFields[nkey] then
- pushResult(field)
+ pushResult(field, true)
hasFounded[nkey] = true
end
end
@@ -244,38 +352,37 @@ function vm.getClassFields(suri, object, key, ref, pushResult)
end
end
end
+ copyToSearched()
-- check local field and global field
- if set.bindSources then
- for _, src in ipairs(set.bindSources) do
- searchFieldSwitch(src.type, suri, src, key, ref, function (field)
+ if not searchedFields[key] and set.bindSource then
+ local src = set.bindSource
+ if src.value and src.value.type == 'table' then
+ searchFieldSwitch('table', suri, src.value, key, ref, function (field)
local fieldKey = guide.getKeyName(field)
if fieldKey then
if not searchedFields[fieldKey]
and guide.isSet(field) then
hasFounded[fieldKey] = true
- pushResult(field)
+ pushResult(field, true)
end
end
end)
- if src.value and src.value.type == 'table' then
- searchFieldSwitch('table', suri, src.value, key, ref, function (field)
- local fieldKey = guide.getKeyName(field)
- if fieldKey then
- if not searchedFields[fieldKey]
- and guide.isSet(field) then
- hasFounded[fieldKey] = true
- pushResult(field)
- end
- end
- end)
- end
end
+ copyToSearched()
+ searchFieldSwitch(src.type, suri, src, key, ref, function (field)
+ local fieldKey = guide.getKeyName(field)
+ if fieldKey and not searchedFields[fieldKey] then
+ if not searchedFields[fieldKey]
+ and guide.isSet(field) then
+ hasFounded[fieldKey] = true
+ pushResult(field, true)
+ end
+ end
+ end)
+ copyToSearched()
end
-- look into extends(if field not found)
- if not hasFounded[key] and set.extends then
- for fieldKey in pairs(hasFounded) do
- searchedFields[fieldKey] = true
- end
+ if not searchedFields[key] and set.extends then
for _, extend in ipairs(set.extends) do
if extend.type == 'doc.extends.name' then
local extendType = vm.getGlobal('type', extend[1])
@@ -284,6 +391,7 @@ function vm.getClassFields(suri, object, key, ref, pushResult)
end
end
end
+ copyToSearched()
end
end
end
@@ -296,7 +404,7 @@ function vm.getClassFields(suri, object, key, ref, pushResult)
for _, set in ipairs(sets) do
pushResult(set)
end
- else
+ elseif type(key) == 'string' then
local global = vm.getGlobal('variable', key)
if global then
for _, set in ipairs(global:getSets(suri)) do
@@ -312,10 +420,10 @@ function vm.getClassFields(suri, object, key, ref, pushResult)
end
---@class parser.object
----@field _sign? vm.sign
+---@field _sign vm.sign|false
---@param source parser.object
----@return vm.sign?
+---@return vm.sign|false
local function getObjectSign(source)
if source._sign ~= nil then
return source._sign
@@ -376,7 +484,7 @@ end
---@param func parser.object
---@param index integer
----@return vm.object?
+---@return (parser.object|vm.generic)?
function vm.getReturnOfFunction(func, index)
if func.type == 'function' then
if not func._returns then
@@ -384,9 +492,9 @@ function vm.getReturnOfFunction(func, index)
end
if not func._returns[index] then
func._returns[index] = {
- type = 'function.return',
- parent = func,
- index = index,
+ type = 'function.return',
+ parent = func,
+ returnIndex = index,
}
end
return func._returns[index]
@@ -394,7 +502,12 @@ function vm.getReturnOfFunction(func, index)
if func.type == 'doc.type.function' then
local rtn = func.returns[index]
if not rtn then
- return nil
+ local lastReturn = func.returns[#func.returns]
+ if lastReturn and lastReturn.name and lastReturn.name[1] == '...' then
+ rtn = lastReturn
+ else
+ return nil
+ end
end
local sign = getObjectSign(func)
if not sign then
@@ -402,8 +515,10 @@ function vm.getReturnOfFunction(func, index)
end
return vm.createGeneric(rtn, sign)
end
+ return nil
end
+---@param args parser.object[]
---@return vm.node
local function getReturnOfSetMetaTable(args)
local tbl = args[1]
@@ -427,88 +542,60 @@ local function getReturnOfSetMetaTable(args)
return node
end
----@return vm.node?
-local function getReturn(func, index, args)
- if func.special == 'setmetatable' then
- if not args then
- return nil
- end
- return getReturnOfSetMetaTable(args)
- end
- if func.special == 'pcall' and index > 1 then
- if not args then
- return nil
- end
- local newArgs = {}
- for i = 2, #args do
- newArgs[#newArgs+1] = args[i]
- end
- return getReturn(args[1], index - 1, newArgs)
- end
- if func.special == 'xpcall' and index > 1 then
- if not args then
- return nil
- end
- local newArgs = {}
- for i = 3, #args do
- newArgs[#newArgs+1] = args[i]
- end
- return getReturn(args[1], index - 1, newArgs)
+---@param source parser.object
+local function matchCall(source)
+ local call = source.parent
+ if not call
+ or call.type ~= 'call'
+ or call.node ~= source then
+ return
end
- if func.special == 'require' then
- if not args then
- return nil
- end
- local nameArg = args[1]
- if not nameArg or nameArg.type ~= 'string' then
- return nil
- end
- local name = nameArg[1]
- if not name or type(name) ~= 'string' then
- return nil
- end
- local uri = rpath.findUrisByRequirePath(guide.getUri(func), name)[1]
- if not uri then
- return nil
- end
- local state = files.getState(uri)
- local ast = state and state.ast
- if not ast then
- return nil
- end
- return vm.compileNode(ast)
+ local funcs = vm.getMatchedFunctions(source, call.args)
+ local myNode = vm.getNode(source)
+ if not myNode then
+ return
end
- local node = vm.compileNode(func)
- ---@type vm.node?
- local result
- for cnode in node:eachObject() do
- if cnode.type == 'function'
- or cnode.type == 'doc.type.function' then
- local returnObject = vm.getReturnOfFunction(cnode, index)
- if returnObject then
- local returnNode = vm.compileNode(returnObject)
- for rnode in returnNode:eachObject() do
- if rnode.type == 'generic' then
- returnNode = rnode:resolve(guide.getUri(func), args)
- break
- end
- end
- if returnNode then
- for rnode in returnNode:eachObject() do
- -- TODO: narrow type
- if rnode.type ~= 'doc.generic.name' then
- result = result or vm.createNode()
- result:merge(rnode)
- end
- end
- if result and returnNode:isOptional() then
- result:addOptional()
- end
+ local needRemove
+ for n in myNode:eachObject() do
+ if n.type == 'function'
+ or n.type == 'doc.type.function' then
+ if not util.arrayHas(funcs, n) then
+ if not needRemove then
+ needRemove = vm.createNode()
end
+ needRemove:merge(n)
end
end
end
- return result
+ if needRemove then
+ local newNode = myNode:copy()
+ newNode:removeNode(needRemove)
+ newNode:setData('originNode', myNode)
+ vm.setNode(source, newNode, true)
+ end
+end
+
+---@param func parser.object
+---@param index integer
+---@param args parser.object[]
+---@return vm.node
+local function getReturn(func, index, args)
+ if not func._callReturns then
+ func._callReturns = {}
+ end
+ if not func._callReturns[index] then
+ local call = func.parent
+ func._callReturns[index] = {
+ type = 'call.return',
+ parent = call,
+ func = func,
+ cindex = index,
+ args = args,
+ start = call.start,
+ finish = call.finish,
+ }
+ end
+ return vm.compileNode(func._callReturns[index])
end
---@param source parser.object
@@ -517,123 +604,136 @@ local function bindAs(source)
local root = guide.getRoot(source)
local docs = root.docs
if not docs then
- return
+ return false
end
- for _, doc in ipairs(docs) do
- if doc.type == 'doc.as' and doc.originalComment.start == source.finish + 2 then
- if doc.as then
- vm.setNode(source, vm.compileNode(doc.as), true)
+ local ases = docs._asCache
+ if not ases then
+ ases = {}
+ docs._asCache = ases
+ for _, doc in ipairs(docs) do
+ if doc.type == 'doc.as' and doc.as and doc.touch then
+ ases[#ases+1] = doc
end
- return true
end
+ table.sort(ases, function (a, b)
+ return a.touch < b.touch
+ end)
end
- return false
-end
-local function bindDocs(source)
- local isParam = source.parent.type == 'funcargs'
- or source.parent.type == 'in'
- local docs = source.bindDocs
- for i = #docs, 1, -1 do
- local doc = docs[i]
- if doc.type == 'doc.type' then
- if not isParam then
- vm.setNode(source, vm.compileNode(doc))
- return true
- end
- end
- if doc.type == 'doc.class' then
- if (source.type == 'local' and not isParam)
- or (source._globalNode and guide.isSet(source))
- or source.type == 'tablefield'
- or source.type == 'tableindex' then
- vm.setNode(source, vm.compileNode(doc))
- return true
- end
- end
- if doc.type == 'doc.param' then
- if isParam and source[1] == doc.param[1] then
- local node = vm.compileNode(doc)
- if doc.optional then
- node:addOptional()
- end
- vm.setNode(source, node)
- return true
- end
- end
- if doc.type == 'doc.module' then
- local name = doc.module
- local uri = rpath.findUrisByRequirePath(guide.getUri(source), name)[1]
- if not uri then
- return nil
- end
- local state = files.getState(uri)
- local ast = state and state.ast
- if not ast then
- return nil
- end
- vm.setNode(source, vm.compileNode(ast))
- return true
- end
- if doc.type == 'doc.overload' then
- if not isParam then
- vm.setNode(source, vm.compileNode(doc))
- end
- end
+ if #ases == 0 then
+ return false
end
- return false
-end
-local function compileByLocalID(source)
- local sources = vm.getLocalSources(source)
- if not sources then
- return
- end
- local hasMarkDoc
- for _, src in ipairs(sources) do
- if src.bindDocs then
- if bindDocs(src) then
- hasMarkDoc = true
- vm.setNode(source, vm.compileNode(src))
- end
+ local max = #ases
+ local index
+ local left = 1
+ local right = max
+ for _ = 1, 1000 do
+ if left == right then
+ index = left
+ break
+ end
+ index = left + (right - left) // 2
+ local doc = ases[index]
+ if doc.touch < source.finish then
+ left = index + 1
+ else
+ right = index
end
end
- for _, src in ipairs(sources) do
- if src.value then
- if not hasMarkDoc or guide.isLiteral(src.value) then
- if src.value.type ~= 'nil' then
- vm.setNode(source, vm.compileNode(src.value))
- end
- end
- end
+
+ local doc = ases[index]
+ if doc and doc.touch == source.finish then
+ vm.setNode(source, vm.compileNode(doc.as), true)
+ return true
end
+
+ return false
end
----@param source vm.node
----@param key? any
+---@param source parser.object
+---@param key? string|vm.global
---@param pushResult fun(source: parser.object)
function vm.compileByParentNode(source, key, ref, pushResult)
local parentNode = vm.compileNode(source)
+ local docedResults = {}
+ local commonResults = {}
+ local mark = {}
local suri = guide.getUri(source)
+ local hasClass
for node in parentNode:eachObject() do
- searchFieldSwitch(node.type, suri, node, key, ref, pushResult)
+ if node.type == 'global'
+ and node.cate == 'type'
+ ---@cast node vm.global
+ and not guide.isBasicType(node.name) then
+ hasClass = true
+ break
+ end
+ end
+ for node in parentNode:eachObject() do
+ if not hasClass
+ or (
+ node.type == 'global'
+ and node.cate == 'type'
+ ---@cast node vm.global
+ and not guide.isBasicType(node.name)
+ )
+ or guide.isLiteral(node) then
+ searchFieldSwitch(node.type, suri, node, key, ref, function (res, markDoc)
+ if mark[res] then
+ return
+ end
+ mark[res] = true
+ if markDoc then
+ docedResults[#docedResults+1] = res
+ else
+ commonResults[#commonResults+1] = res
+ end
+ end)
+ end
end
-end
----@return vm.node?
-local function selectNode(source, list, index)
- if not list then
- return nil
+ if not next(mark) then
+ searchFieldByLocalID(source, key, ref, function (res, markDoc)
+ if mark[res] then
+ return
+ end
+ mark[res] = true
+ if markDoc then
+ docedResults[#docedResults+1] = res
+ else
+ commonResults[#commonResults+1] = res
+ end
+ end)
end
+
+ if #docedResults > 0 then
+ for _, res in ipairs(docedResults) do
+ pushResult(res)
+ end
+ end
+ if #docedResults == 0 or key == nil then
+ for _, res in ipairs(commonResults) do
+ pushResult(res)
+ end
+ end
+end
+
+---@param list parser.object[]
+---@param index integer
+---@return vm.node
+---@return parser.object?
+function vm.selectNode(list, index)
local exp
if list[index] then
exp = list[index]
+ index = 1
else
for i = index, 1, -1 do
if list[i] then
local last = list[i]
if last.type == 'call'
- or last.type == '...' then
+ or last.type == 'varargs' then
index = index - i + 1
exp = last
end
@@ -642,40 +742,46 @@ local function selectNode(source, list, index)
end
end
if not exp then
- return nil
+ return vm.createNode(vm.declareGlobal('type', 'nil')), nil
end
+ ---@type vm.node?
local result
if exp.type == 'call' then
result = getReturn(exp.node, index, exp.args)
- if not result then
- vm.setNode(source, vm.declareGlobal('type', 'unknown'))
- return vm.getNode(source)
+ if result:isEmpty() then
+ result:merge(vm.declareGlobal('type', 'unknown'))
end
else
+ ---@type vm.node
result = vm.compileNode(exp)
+ if result:isEmpty() then
+ result:merge(vm.declareGlobal('type', 'unknown'))
+ end
end
+ return result, exp
+end
+
+---@param source parser.object
+---@param list parser.object[]
+---@param index integer
+---@return vm.node
+local function selectNode(source, list, index)
+ local result = vm.selectNode(list, index)
if source.type == 'function.return' then
-- remove any for returns
local rtnNode = vm.createNode()
- local hasKnownType
for n in result:eachObject() do
if guide.isLiteral(n) then
- hasKnownType = true
rtnNode:merge(n)
end
if n.type == 'global' and n.cate == 'type' then
- if n.name ~= 'any'
- and n.name ~= 'unknown' then
- hasKnownType = true
+ if n.name ~= 'any' then
rtnNode:merge(n)
end
else
rtnNode:merge(n)
end
end
- if not hasKnownType then
- rtnNode:merge(vm.declareGlobal('type', 'unknown'))
- end
vm.setNode(source, rtnNode)
return rtnNode
end
@@ -684,15 +790,19 @@ local function selectNode(source, list, index)
end
---@param source parser.object
----@param node vm.object
+---@param node vm.node.object
---@return boolean
local function isValidCallArgNode(source, node)
if source.type == 'function' then
return node.type == 'doc.type.function'
end
if source.type == 'table' then
- return node.type == 'doc.type.table'
- or (node.type == 'global' and node.cate == 'type' and not guide.isBasicType(node.name))
+ return node.type == 'doc.type.table' or node.type == 'doc.type.array'
+ or ( node.type == 'global'
+ and node.cate == 'type'
+ ---@cast node vm.global
+ and not guide.isBasicType(node.name)
+ )
end
if source.type == 'dummyarg' then
return true
@@ -741,12 +851,14 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
for n in callNode:eachObject() do
if n.type == 'function' then
+ ---@cast n parser.object
local sign = getObjectSign(n)
local farg = getFuncArg(n, myIndex)
if farg then
for fn in vm.compileNode(farg):eachObject() do
if isValidCallArgNode(arg, fn) then
if fn.type == 'doc.type.function' then
+ ---@cast fn parser.object
if sign then
local generic = vm.createGeneric(fn, sign)
local args = {}
@@ -762,6 +874,7 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
end
end
if n.type == 'doc.type.function' then
+ ---@cast n parser.object
local myEvent
if n.args[eventIndex] then
local argNode = vm.compileNode(n.args[eventIndex])
@@ -788,6 +901,7 @@ end
---@param arg parser.object
---@param call parser.object
---@param index? integer
+---@return vm.node?
function vm.compileCallArg(arg, call, index)
if not index then
for i, carg in ipairs(call.args) do
@@ -796,6 +910,9 @@ function vm.compileCallArg(arg, call, index)
break
end
end
+ if not index then
+ return nil
+ end
end
local callNode = vm.compileNode(call.node)
@@ -812,8 +929,44 @@ function vm.compileCallArg(arg, call, index)
return vm.getNode(arg)
end
+---@class parser.object
+---@field _iterator? table
+---@field _iterArgs? table
+---@field _iterVars? table<parser.object, vm.node>
+
+---@param source parser.object
+local function compileForVars(source)
+ if source._iterator then
+ return
+ end
+ if not source.exps then
+ return
+ end
+ -- for k, v in pairs(t) do
+ --> for k, v in iterator, status, initValue do
+ --> local k, v = iterator(status, initValue)
+ source._iterator = {
+ type = 'dummyfunc',
+ parent = source,
+ }
+ source._iterArgs = {{},{}}
+ source._iterVars = {}
+ -- iterator
+ selectNode(source._iterator, source.exps, 1)
+ -- status
+ selectNode(source._iterArgs[1], source.exps, 2)
+ -- initValue
+ selectNode(source._iterArgs[2], source.exps, 3)
+ if source.keys then
+ for i, loc in ipairs(source.keys) do
+ local node = getReturn(source._iterator, i, source._iterArgs)
+ node:removeOptional()
+ source._iterVars[loc] = node
+ end
+ end
+end
+
---@param source parser.object
----@return vm.node
local function compileLocal(source)
vm.setNode(source, source)
@@ -835,15 +988,25 @@ local function compileLocal(source)
vm.setNode(source, vm.compileNode(source.parent.parent.parent.node))
end
end
+ vm.getNode(source):remove 'function'
end
local hasMarkValue
- if source.value then
- if not hasMarkDoc or guide.isLiteral(source.value) then
- hasMarkValue = true
- if source.value.type == 'table' then
- vm.setNode(source, source.value)
- elseif source.value.type ~= 'nil' then
- vm.setNode(source, vm.compileNode(source.value))
+ if not hasMarkDoc and source.value then
+ hasMarkValue = true
+ if source.value.type == 'table' then
+ vm.setNode(source, source.value)
+ elseif source.value.type ~= 'nil' then
+ vm.setNode(source, vm.compileNode(source.value))
+ end
+ end
+ if not hasMarkValue and not hasMarkValue then
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ if ref.type == 'setlocal'
+ and ref.value
+ and ref.value.type == 'function' then
+ vm.setNode(source, vm.compileNode(ref.value))
+ end
end
end
end
@@ -883,12 +1046,21 @@ local function compileLocal(source)
end
-- for x in ... do
if source.parent.type == 'in' then
- vm.compileNode(source.parent)
+ compileForVars(source.parent)
+ local keyNode = source.parent._iterVars and source.parent._iterVars[source]
+ if keyNode then
+ vm.setNode(source, keyNode)
+ end
end
-- for x = ... do
if source.parent.type == 'loop' then
- vm.compileNode(source.parent)
+ if source.parent.loc == source then
+ if bindDocs(source) then
+ return
+ end
+ vm.setNode(source, vm.declareGlobal('type', 'integer'))
+ end
end
vm.getNode(source):setData('hasDefined', hasMarkDoc or hasMarkParam or hasMarkValue)
@@ -920,9 +1092,21 @@ local compilerSwitch = util.switch()
or source.parent.type == 'setlocal'
or source.parent.type == 'tablefield'
or source.parent.type == 'tableindex'
+ or source.parent.type == 'tableexp'
or source.parent.type == 'setfield'
or source.parent.type == 'setindex' then
- vm.setNode(source, vm.compileNode(source.parent))
+ local parentNode = vm.compileNode(source.parent)
+ for _, pn in ipairs(parentNode) do
+ if pn.type == 'global'
+ and pn.cate == 'type' then
+ ---@cast pn vm.global
+ if not guide.isBasicType(pn.name) then
+ vm.setNode(source, pn)
+ end
+ elseif pn.type == 'doc.type.table' or pn.type == 'doc.type.array' then
+ vm.setNode(source, pn)
+ end
+ end
end
end)
: case 'function'
@@ -964,8 +1148,7 @@ local compilerSwitch = util.switch()
local hasMark = vm.getNode(source):getData 'hasDefined'
- local runner = vm.createRunner(source)
- runner:launch(function (src, node)
+ vm.launchRunner(source, function (src, node)
if src.type == 'setlocal' then
if src.bindDocs then
for _, doc in ipairs(src.bindDocs) do
@@ -975,18 +1158,32 @@ local compilerSwitch = util.switch()
end
end
end
- if src.value and guide.isLiteral(src.value) then
+ if src.value then
if src.value.type == 'table' then
- vm.setNode(src, vm.createNode(src.value), true)
+ vm.setNode(src, vm.createNode(src.value))
+ vm.setNode(src, node:copy():asTable())
else
+ local function clearLockedNode(child)
+ if not child then
+ return
+ end
+ if child.type == 'function' then
+ return
+ end
+ if child.type == 'setlocal'
+ or child.type == 'getlocal' then
+ if child.node == source then
+ return
+ end
+ end
+ if LOCK[child] then
+ vm.removeNode(child)
+ end
+ guide.eachChild(child, clearLockedNode)
+ end
+ clearLockedNode(src.value)
vm.setNode(src, vm.compileNode(src.value), true)
end
- elseif src.value
- and src.value.type == 'binary'
- and src.value.op and src.value.op.type == 'or'
- and src.value[1] and src.value[1].type == 'getlocal' and src.value[1].node == source then
- -- x = x or 1
- vm.setNode(src, vm.compileNode(src.value))
else
vm.setNode(src, node, true)
end
@@ -996,6 +1193,7 @@ local compilerSwitch = util.switch()
return
end
vm.setNode(src, node, true)
+ matchCall(src)
end
end)
@@ -1004,7 +1202,10 @@ local compilerSwitch = util.switch()
for _, ref in ipairs(source.ref) do
if ref.type == 'setlocal'
and guide.getParentFunction(ref) == parentFunc then
- vm.setNode(source, vm.getNode(ref))
+ local refNode = vm.getNode(ref)
+ if refNode then
+ vm.setNode(source, refNode)
+ end
end
end
end
@@ -1023,27 +1224,14 @@ local compilerSwitch = util.switch()
: case 'setfield'
: case 'setmethod'
: case 'setindex'
- : call(function (source)
- compileByLocalID(source)
- local key = guide.getKeyName(source)
- if key == nil then
- return
- end
- vm.compileByParentNode(source.node, key, false, function (src)
- if src.type == 'doc.type.field'
- or src.type == 'doc.field' then
- vm.setNode(source, vm.compileNode(src))
- end
- end)
- end)
: case 'getfield'
: case 'getmethod'
: case 'getindex'
: call(function (source)
- if bindAs(source) then
+ if guide.isGet(source) and bindAs(source) then
return
end
- compileByLocalID(source)
+ ---@type (string|vm.node)?
local key = guide.getKeyName(source)
if key == nil and source.index then
key = vm.compileNode(source.index)
@@ -1052,14 +1240,39 @@ local compilerSwitch = util.switch()
return
end
if type(key) == 'table' then
+ ---@cast key vm.node
local uri = guide.getUri(source)
local value = vm.getTableValue(uri, vm.compileNode(source.node), key)
if value then
vm.setNode(source, value)
end
+ for k in key:eachObject() do
+ if k.type == 'global' and k.cate == 'type' then
+ ---@cast k vm.global
+ vm.compileByParentNode(source.node, k, false, function (src)
+ vm.setNode(source, vm.compileNode(src))
+ if src.value then
+ vm.setNode(source, vm.compileNode(src.value))
+ end
+ end)
+ end
+ end
else
+ ---@cast key string
vm.compileByParentNode(source.node, key, false, function (src)
- vm.setNode(source, vm.compileNode(src))
+ if src.value then
+ if bindDocs(src) then
+ vm.setNode(source, vm.compileNode(src))
+ elseif src.value.type ~= 'nil' then
+ vm.setNode(source, vm.compileNode(src.value))
+ local node = vm.getNode(src)
+ if node then
+ vm.setNode(source, node)
+ end
+ end
+ else
+ vm.setNode(source, vm.compileNode(src))
+ end
end)
end
end)
@@ -1097,21 +1310,20 @@ local compilerSwitch = util.switch()
hasMarkDoc = bindDocs(source)
end
- if source.value then
- if not hasMarkDoc or guide.isLiteral(source.value) then
- if source.value.type == 'table' then
- vm.setNode(source, source.value)
- elseif source.value.type ~= 'nil' then
- vm.setNode(source, vm.compileNode(source.value))
+ if not hasMarkDoc then
+ vm.compileByParentNode(source.node, guide.getKeyName(source), false, function (src)
+ if src.type == 'doc.field'
+ or src.type == 'doc.type.field' then
+ hasMarkDoc = true
+ vm.setNode(source, vm.compileNode(src))
end
- end
+ end)
end
- if not hasMarkDoc then
- vm.compileByParentNode(source.parent, guide.getKeyName(source), false, function (src)
- vm.setNode(source, vm.compileNode(src))
- end)
+ if not hasMarkDoc and source.value then
+ vm.setNode(source, vm.compileNode(source.value))
end
+
end)
: case 'field'
: case 'method'
@@ -1120,18 +1332,29 @@ local compilerSwitch = util.switch()
end)
: case 'tableexp'
: call(function (source)
+ if (source.parent.type == 'table') then
+ local node = vm.compileNode(source.parent)
+ for n in node:eachObject() do
+ if n.type == 'doc.type.array' then
+ vm.setNode(source, vm.compileNode(n.node))
+ end
+ end
+ end
vm.setNode(source, vm.compileNode(source.value))
end)
: case 'function.return'
+ ---@param source parser.object
: call(function (source)
local func = source.parent
- local index = source.index
+ local index = source.returnIndex
local hasMarkDoc
if func.bindDocs then
local sign = getObjectSign(func)
+ local lastReturn
for _, doc in ipairs(func.bindDocs) do
if doc.type == 'doc.return' then
for _, rtn in ipairs(doc.returns) do
+ lastReturn = rtn
if rtn.returnIndex == index then
hasMarkDoc = true
local hasGeneric
@@ -1141,6 +1364,7 @@ local compilerSwitch = util.switch()
end)
end
if hasGeneric then
+ ---@cast sign -false
vm.setNode(source, vm.createGeneric(rtn, sign))
else
vm.setNode(source, vm.compileNode(rtn))
@@ -1149,10 +1373,146 @@ local compilerSwitch = util.switch()
end
end
end
+ if lastReturn
+ and not hasMarkDoc then
+ if lastReturn.name and lastReturn.name[1] == '...' then
+ hasMarkDoc = true
+ vm.setNode(source, vm.compileNode(lastReturn))
+ end
+ end
end
+ local hasReturn
if func.returns and not hasMarkDoc then
for _, rtn in ipairs(func.returns) do
- selectNode(source, rtn, index)
+ if selectNode(source, rtn, index) then
+ hasReturn = true
+ end
+ end
+ if hasReturn then
+ local hasKnownType
+ local hasUnknownType
+ for n in vm.getNode(source):eachObject() do
+ if guide.isLiteral(n) then
+ if n.type ~= 'nil' then
+ hasKnownType = true
+ break
+ end
+ goto CONTINUE
+ end
+ if n.type == 'global' and n.cate == 'type' then
+ if n.name ~= 'nil' then
+ hasKnownType = true
+ break
+ end
+ goto CONTINUE
+ end
+ hasUnknownType = true
+ ::CONTINUE::
+ end
+ if not hasKnownType and hasUnknownType then
+ vm.setNode(source, vm.declareGlobal('type', 'unknown'))
+ end
+ end
+ end
+ if not hasMarkDoc and not hasReturn then
+ vm.setNode(source, vm.declareGlobal('type', 'nil'))
+ end
+ end)
+ : case 'call.return'
+ ---@param source parser.object
+ : call(function (source)
+ if bindAs(source) then
+ return
+ end
+ local func = source.func
+ local args = source.args
+ local index = source.cindex
+ if func.special == 'setmetatable' then
+ if not args then
+ return
+ end
+ vm.setNode(source, getReturnOfSetMetaTable(args))
+ return
+ end
+ if func.special == 'pcall' and index > 1 then
+ if not args then
+ return
+ end
+ local newArgs = {}
+ for i = 2, #args do
+ newArgs[#newArgs+1] = args[i]
+ end
+ local node = getReturn(args[1], index - 1, newArgs)
+ if node then
+ vm.setNode(source, node)
+ end
+ return
+ end
+ if func.special == 'xpcall' and index > 1 then
+ if not args then
+ return
+ end
+ local newArgs = {}
+ for i = 3, #args do
+ newArgs[#newArgs+1] = args[i]
+ end
+ local node = getReturn(args[1], index - 1, newArgs)
+ if node then
+ vm.setNode(source, node)
+ end
+ return
+ end
+ if func.special == 'require' then
+ if not args then
+ return
+ end
+ local nameArg = args[1]
+ if not nameArg or nameArg.type ~= 'string' then
+ return
+ end
+ local name = nameArg[1]
+ if not name or type(name) ~= 'string' then
+ return
+ end
+ local uri = rpath.findUrisByRequireName(guide.getUri(func), name)[1]
+ if not uri then
+ return
+ end
+ local state = files.getState(uri)
+ local ast = state and state.ast
+ if not ast then
+ return
+ end
+ vm.setNode(source, vm.compileNode(ast))
+ return
+ end
+ local funcNode = vm.compileNode(func)
+ ---@type vm.node?
+ for mfunc in funcNode:eachObject() do
+ if mfunc.type == 'function'
+ or mfunc.type == 'doc.type.function' then
+ ---@cast mfunc parser.object
+ local returnObject = vm.getReturnOfFunction(mfunc, index)
+ if returnObject then
+ local returnNode = vm.compileNode(returnObject)
+ for rnode in returnNode:eachObject() do
+ if rnode.type == 'generic' then
+ returnNode = rnode:resolve(guide.getUri(func), args)
+ break
+ end
+ end
+ if returnNode then
+ for rnode in returnNode:eachObject() do
+ -- TODO: narrow type
+ if rnode.type ~= 'doc.generic.name' then
+ vm.setNode(source, rnode)
+ end
+ end
+ if returnNode:isOptional() then
+ vm.getNode(source):addOptional()
+ end
+ end
+ end
end
end
end)
@@ -1174,12 +1534,8 @@ local compilerSwitch = util.switch()
if not node then
return
end
- for n in node:eachObject() do
- if n.type == 'global'
- and n.cate == 'type'
- and n.name == '...' then
- return
- end
+ if node:isEmpty() then
+ node = vm.runOperator('call', vararg.node) or node
end
vm.setNode(source, node)
end
@@ -1199,51 +1555,11 @@ local compilerSwitch = util.switch()
if not node then
return
end
- for n in node:eachObject() do
- if n.type == 'global'
- and n.cate == 'type'
- and n.name == '...' then
- return
- end
+ if node:isEmpty() then
+ node = vm.runOperator('call', source.node) or node
end
vm.setNode(source, node)
end)
- : case 'in'
- : call(function (source)
- if not source._iterator then
- -- for k, v in pairs(t) do
- --> for k, v in iterator, status, initValue do
- --> local k, v = iterator(status, initValue)
- source._iterator = {
- type = 'dummyfunc',
- parent = source,
- }
- source._iterArgs = {{},{}}
- end
- -- iterator
- selectNode(source._iterator, source.exps, 1)
- -- status
- selectNode(source._iterArgs[1], source.exps, 2)
- -- initValue
- selectNode(source._iterArgs[2], source.exps, 3)
- if source.keys then
- for i, loc in ipairs(source.keys) do
- local node = getReturn(source._iterator, i, source._iterArgs)
- if node then
- if i == 1 then
- node:removeOptional()
- end
- vm.setNode(loc, node)
- end
- end
- end
- end)
- : case 'loop'
- : call(function (source)
- if source.loc then
- vm.setNode(source.loc, vm.declareGlobal('type', 'integer'))
- end
- end)
: case 'doc.type'
: call(function (source)
for _, typeUnit in ipairs(source.types) do
@@ -1256,6 +1572,7 @@ local compilerSwitch = util.switch()
: case 'doc.type.integer'
: case 'doc.type.string'
: case 'doc.type.boolean'
+ : case 'doc.type.code'
: call(function (source)
vm.setNode(source, source)
end)
@@ -1299,6 +1616,10 @@ local compilerSwitch = util.switch()
: call(function (source)
vm.setNode(source, vm.compileNode(source.parent))
end)
+ : case 'doc.enum.name'
+ : call(function (source)
+ vm.setNode(source, vm.compileNode(source.parent))
+ end)
: case 'doc.field'
: call(function (source)
if not source.extends then
@@ -1337,18 +1658,14 @@ local compilerSwitch = util.switch()
end)
: case '...'
: call(function (source)
- local func = source.parent.parent
- if func.type ~= 'function' then
- return
- end
- if not func.bindDocs then
+ if not source.bindDocs then
return
end
- for _, doc in ipairs(func.bindDocs) do
+ for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.vararg' then
vm.setNode(source, vm.compileNode(doc))
end
- if doc.type == 'doc.param' and doc.param[1] == '...' then
+ if doc.type == 'doc.param' then
vm.setNode(source, vm.compileNode(doc))
end
end
@@ -1361,7 +1678,7 @@ local compilerSwitch = util.switch()
: call(function (source)
local type = vm.getGlobal('type', source[1])
if type then
- vm.setNode(source, vm.compileNode(type))
+ vm.setNode(source, type)
end
end)
: case 'doc.type.arg'
@@ -1375,10 +1692,6 @@ local compilerSwitch = util.switch()
vm.getNode(source):addOptional()
end
end)
- : case 'generic'
- : call(function (source)
- vm.setNode(source, source)
- end)
: case 'unary'
: call(function (source)
if bindAs(source) then
@@ -1387,58 +1700,7 @@ local compilerSwitch = util.switch()
if not source[1] then
return
end
- if source.op.type == 'not' then
- local result = vm.test(source[1])
- if result == nil then
- vm.setNode(source, vm.declareGlobal('type', 'boolean'))
- return
- else
- vm.setNode(source, {
- type = 'boolean',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = not result,
- })
- return
- end
- end
- if source.op.type == '#' then
- vm.setNode(source, vm.declareGlobal('type', 'integer'))
- return
- end
- if source.op.type == '-' then
- local v = vm.getNumber(source[1])
- if v == nil then
- vm.setNode(source, vm.declareGlobal('type', 'number'))
- return
- else
- vm.setNode(source, {
- type = 'number',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = -v,
- })
- return
- end
- end
- if source.op.type == '~' then
- local v = vm.getInteger(source[1])
- if v == nil then
- vm.setNode(source, vm.declareGlobal('type', 'integer'))
- return
- else
- vm.setNode(source, {
- type = 'integer',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = ~v,
- })
- return
- end
- end
+ vm.unarySwich(source.op.type, source)
end)
: case 'binary'
: call(function (source)
@@ -1448,323 +1710,21 @@ local compilerSwitch = util.switch()
if not source[1] or not source[2] then
return
end
- if source.op.type == 'and' then
- local node1 = vm.compileNode(source[1])
- local node2 = vm.compileNode(source[2])
- local r1 = vm.test(source[1])
- if r1 == true then
- vm.setNode(source, node2)
- elseif r1 == false then
- vm.setNode(source, node1)
- else
- vm.setNode(source, node2)
- end
- end
- if source.op.type == 'or' then
- local node1 = vm.compileNode(source[1])
- local node2 = vm.compileNode(source[2])
- local r1 = vm.test(source[1])
- if r1 == true then
- vm.setNode(source, node1)
- elseif r1 == false then
- vm.setNode(source, node2)
- else
- vm.getNode(source):merge(node1)
- vm.getNode(source):setTruthy()
- vm.getNode(source):merge(node2)
- end
- end
- if source.op.type == '==' then
- local result = vm.equal(source[1], source[2])
- if result == nil then
- vm.setNode(source, vm.declareGlobal('type', 'boolean'))
- return
- else
- vm.setNode(source, {
- type = 'boolean',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = result,
- })
- return
- end
- end
- if source.op.type == '~=' then
- local result = vm.equal(source[1], source[2])
- if result == nil then
- vm.setNode(source, vm.declareGlobal('type', 'boolean'))
- return
- else
- vm.setNode(source, {
- type = 'boolean',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = not result,
- })
- return
- end
- end
- if source.op.type == '<<' then
- local a = vm.getInteger(source[1])
- local b = vm.getInteger(source[2])
- if a and b then
- vm.setNode(source, {
- type = 'integer',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = a << b,
- })
- return
- else
- vm.setNode(source, vm.declareGlobal('type', 'integer'))
- return
- end
- end
- if source.op.type == '>>' then
- local a = vm.getInteger(source[1])
- local b = vm.getInteger(source[2])
- if a and b then
- vm.setNode(source, {
- type = 'integer',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = a >> b,
- })
- return
- else
- vm.setNode(source, vm.declareGlobal('type', 'integer'))
- return
- end
- end
- if source.op.type == '&' then
- local a = vm.getInteger(source[1])
- local b = vm.getInteger(source[2])
- if a and b then
- vm.setNode(source, {
- type = 'integer',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = a & b,
- })
- return
- else
- vm.setNode(source, vm.declareGlobal('type', 'integer'))
- return
- end
- end
- if source.op.type == '|' then
- local a = vm.getInteger(source[1])
- local b = vm.getInteger(source[2])
- if a and b then
- vm.setNode(source, {
- type = 'integer',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = a | b,
- })
- return
- else
- vm.setNode(source, vm.declareGlobal('type', 'integer'))
- return
- end
- end
- if source.op.type == '~' then
- local a = vm.getInteger(source[1])
- local b = vm.getInteger(source[2])
- if a and b then
- vm.setNode(source, {
- type = 'integer',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = a ~ b,
- })
- return
- else
- vm.setNode(source, vm.declareGlobal('type', 'integer'))
- return
- end
- end
- if source.op.type == '+' then
- local a = vm.getNumber(source[1])
- local b = vm.getNumber(source[2])
- if a and b then
- local result = a + b
- vm.setNode(source, {
- type = math.type(result) == 'integer' and 'integer' or 'number',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = result,
- })
- return
- else
- vm.setNode(source, vm.declareGlobal('type', 'number'))
- return
- end
- end
- if source.op.type == '-' then
- local a = vm.getNumber(source[1])
- local b = vm.getNumber(source[2])
- if a and b then
- local result = a - b
- vm.setNode(source, {
- type = math.type(result) == 'integer' and 'integer' or 'number',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = result,
- })
- return
- else
- vm.setNode(source, vm.declareGlobal('type', 'number'))
- return
- end
- end
- if source.op.type == '*' then
- local a = vm.getNumber(source[1])
- local b = vm.getNumber(source[2])
- if a and b then
- local result = a * b
- vm.setNode(source, {
- type = math.type(result) == 'integer' and 'integer' or 'number',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = result,
- })
- return
- else
- vm.setNode(source, vm.declareGlobal('type', 'number'))
- return
- end
- end
- if source.op.type == '/' then
- local a = vm.getNumber(source[1])
- local b = vm.getNumber(source[2])
- if a and b then
- vm.setNode(source, {
- type = 'number',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = a / b,
- })
- return
- else
- vm.setNode(source, vm.declareGlobal('type', 'number'))
- return
- end
- end
- if source.op.type == '%' then
- local a = vm.getNumber(source[1])
- local b = vm.getNumber(source[2])
- if a and b and b ~= 0 then
- local result = a % b
- vm.setNode(source, {
- type = math.type(result) == 'integer' and 'integer' or 'number',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = result,
- })
- return
- else
- vm.setNode(source, vm.declareGlobal('type', 'number'))
- return
- end
- end
- if source.op.type == '^' then
- local a = vm.getNumber(source[1])
- local b = vm.getNumber(source[2])
- if a and b then
- vm.setNode(source, {
- type = 'number',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = a ^ b,
- })
- return
- else
- vm.setNode(source, vm.declareGlobal('type', 'number'))
- return
- end
- end
- if source.op.type == '//' then
- local a = vm.getNumber(source[1])
- local b = vm.getNumber(source[2])
- if a and b and b ~= 0 then
- local result = a // b
- vm.setNode(source, {
- type = math.type(result) == 'integer' and 'integer' or 'number',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = result,
- })
- return
- else
- vm.setNode(source, vm.declareGlobal('type', 'number'))
- return
- end
- end
- if source.op.type == '..' then
- local a = vm.getString(source[1])
- or vm.getNumber(source[1])
- local b = vm.getString(source[2])
- or vm.getNumber(source[2])
- if a and b then
- if type(a) == 'number' or type(b) == 'number' then
- local uri = guide.getUri(source)
- local version = config.get(uri, 'Lua.runtime.version')
- if math.tointeger(a) and math.type(a) == 'float' then
- if version == 'Lua 5.3' or version == 'Lua 5.4' then
- a = ('%.1f'):format(a)
- else
- a = ('%.0f'):format(a)
- end
- end
- if math.tointeger(b) and math.type(b) == 'float' then
- if version == 'Lua 5.3' or version == 'Lua 5.4' then
- b = ('%.1f'):format(b)
- else
- b = ('%.0f'):format(b)
- end
- end
- end
- vm.setNode(source, {
- type = 'string',
- start = source.start,
- finish = source.finish,
- parent = source,
- [1] = a .. b,
- })
- return
- else
- vm.setNode(source, vm.declareGlobal('type', 'string'))
- return
- end
- end
+ vm.binarySwitch(source.op.type, source)
end)
----@param source vm.object
+---@param source parser.object
local function compileByNode(source)
compilerSwitch(source.type, source)
end
----@param source vm.object
+---@param source parser.object
local function compileByGlobal(source)
local global = source._globalNode
if not global then
return
end
+ ---@cast source parser.object
local root = guide.getRoot(source)
local uri = guide.getUri(source)
if not root._globalBase then
@@ -1800,22 +1760,30 @@ local function compileByGlobal(source)
if global.cate == 'variable' then
local hasMarkDoc
for _, set in ipairs(global:getSets(uri)) do
- if set.bindDocs then
+ if set.bindDocs and set.parent.type == 'main' then
if bindDocs(set) then
globalNode:merge(vm.compileNode(set))
hasMarkDoc = true
end
+ if vm.getNode(set) then
+ globalNode:merge(vm.compileNode(set))
+ end
end
end
+ -- Set all globals node first to avoid recursive
for _, set in ipairs(global:getSets(uri)) do
- if set.value then
+ vm.setNode(set, globalNode, true)
+ end
+ for _, set in ipairs(global:getSets(uri)) do
+ if set.value and set.value.type ~= 'nil' and set.parent.type == 'main' then
if not hasMarkDoc or guide.isLiteral(set.value) then
- if set.value.type ~= 'nil' then
- globalNode:merge(vm.compileNode(set.value))
- end
+ globalNode:merge(vm.compileNode(set.value))
end
end
end
+ for _, set in ipairs(global:getSets(uri)) do
+ vm.setNode(set, globalNode, true)
+ end
end
if global.cate == 'type' then
for _, set in ipairs(global:getSets(uri)) do
@@ -1850,21 +1818,27 @@ function vm.compileNode(source)
end
end
- if source.type == 'global' then
- return source
- end
-
local cache = vm.getNode(source)
if cache ~= nil then
return cache
end
- local node = vm.createNode()
- vm.setNode(source, node, true)
+ if source.type == 'generic' then
+ vm.setNode(source, source)
+ local node = vm.getNode(source)
+ ---@cast node -?
+ return node
+ end
+
+ ---@cast source parser.object
+ vm.setNode(source, vm.createNode(), true)
+ LOCK[source] = true
compileByGlobal(source)
compileByNode(source)
+ matchCall(source)
+ LOCK[source] = nil
- node = vm.getNode(source)
-
+ local node = vm.getNode(source)
+ ---@cast node -?
return node
end
diff --git a/script/vm/def.lua b/script/vm/def.lua
index 83e92686..f557f221 100644
--- a/script/vm/def.lua
+++ b/script/vm/def.lua
@@ -5,72 +5,7 @@ local guide = require 'parser.guide'
local simpleSwitch
-local function searchGetLocal(source, node, pushResult)
- local key = guide.getKeyName(source)
- for _, ref in ipairs(node.node.ref) do
- if ref.type == 'getlocal'
- and ref.next
- and guide.isSet(ref.next)
- and guide.getKeyName(ref.next) == key then
- pushResult(ref.next)
- end
- end
-end
-
simpleSwitch = util.switch()
- : case 'local'
- : call(function (source, pushResult)
- pushResult(source)
- if source.ref then
- for _, ref in ipairs(source.ref) do
- if ref.type == 'setlocal' then
- pushResult(ref)
- end
- end
- end
- end)
- : case 'sellf'
- : call(function (source, pushResult)
- if source.ref then
- for _, ref in ipairs(source.ref) do
- if ref.type == 'setlocal' then
- pushResult(ref)
- end
- end
- end
- for _, res in ipairs(vm.getDefs(source.method.node)) do
- pushResult(res)
- end
- end)
- : case 'getlocal'
- : case 'setlocal'
- : call(function (source, pushResult)
- simpleSwitch('local', source.node, pushResult)
- end)
- : case 'field'
- : call(function (source, pushResult)
- local parent = source.parent
- if parent.type ~= 'tablefield' then
- simpleSwitch(parent.type, parent, pushResult)
- end
- end)
- : case 'setfield'
- : case 'getfield'
- : call(function (source, pushResult)
- local node = source.node
- if node.type == 'getlocal' then
- searchGetLocal(source, node, pushResult)
- return
- end
- end)
- : case 'getindex'
- : case 'setindex'
- : call(function (source, pushResult)
- local node = source.node
- if node.type == 'getlocal' then
- searchGetLocal(source, node, pushResult)
- end
- end)
: case 'goto'
: call(function (source, pushResult)
if source.node then
@@ -98,7 +33,7 @@ local searchFieldSwitch = util.switch()
end
end)
: case 'global'
- ---@param obj vm.object
+ ---@param obj vm.global
---@param key string
: call(function (suri, obj, key, pushResult)
if obj.cate == 'variable' then
@@ -115,12 +50,10 @@ local searchFieldSwitch = util.switch()
end)
: case 'local'
: call(function (suri, obj, key, pushResult)
- local sources = vm.getLocalSources(obj, key)
+ local sources = vm.getLocalSourcesSets(obj, key)
if sources then
for _, src in ipairs(sources) do
- if guide.isSet(src) then
- pushResult(src)
- end
+ pushResult(src)
end
end
end)
@@ -152,7 +85,7 @@ local nodeSwitch;nodeSwitch = util.switch()
local parentNode = vm.compileNode(source.node)
local uri = guide.getUri(source)
local key = guide.getKeyName(source)
- if not key then
+ if type(key) ~= 'string' then
return
end
if lastKey then
@@ -169,9 +102,15 @@ local nodeSwitch;nodeSwitch = util.switch()
if lastKey then
return
end
- local tbl = source.parent
+ local key = guide.getKeyName(source)
+ if type(key) ~= 'string' then
+ return
+ end
local uri = guide.getUri(source)
- searchFieldSwitch(tbl.type, uri, tbl, guide.getKeyName(source), pushResult)
+ local parentNode = vm.compileNode(source.node)
+ for pn in parentNode:eachObject() do
+ searchFieldSwitch(pn.type, uri, pn, key, pushResult)
+ end
end)
: case 'doc.see.field'
: call(function (source, lastKey, pushResult)
@@ -194,14 +133,12 @@ end
---@param source parser.object
---@param pushResult fun(src: parser.object)
local function searchByLocalID(source, pushResult)
- local idSources = vm.getLocalSources(source)
+ local idSources = vm.getLocalSourcesSets(source)
if not idSources then
return
end
for _, src in ipairs(idSources) do
- if guide.isSet(src) then
- pushResult(src)
- end
+ pushResult(src)
end
end
diff --git a/script/vm/doc.lua b/script/vm/doc.lua
index e2b383b6..293cf5c3 100644
--- a/script/vm/doc.lua
+++ b/script/vm/doc.lua
@@ -20,6 +20,8 @@ function vm.getDocSets(suri, name)
end
end
+---@param uri uri
+---@return boolean
function vm.isMetaFile(uri)
local status = files.getState(uri)
if not status then
@@ -45,6 +47,8 @@ function vm.isMetaFile(uri)
return false
end
+---@param doc parser.object
+---@return table<string, boolean>?
function vm.getValidVersions(doc)
if doc.type ~= 'doc.version' then
return
@@ -87,13 +91,14 @@ function vm.getValidVersions(doc)
return valids
end
+---@param value parser.object
---@return parser.object?
local function getDeprecated(value)
if not value.bindDocs then
- return false
+ return nil
end
if value._deprecated ~= nil then
- return value._deprecated
+ return value._deprecated or nil
end
for _, doc in ipairs(value.bindDocs) do
if doc.type == 'doc.deprecated' then
@@ -101,24 +106,26 @@ local function getDeprecated(value)
return doc
elseif doc.type == 'doc.version' then
local valids = vm.getValidVersions(doc)
- if not valids[config.get(guide.getUri(value), 'Lua.runtime.version')] then
+ if valids and not valids[config.get(guide.getUri(value), 'Lua.runtime.version')] then
value._deprecated = doc
return doc
end
end
end
value._deprecated = false
- return false
+ return nil
end
+---@param value parser.object
+---@param deep boolean?
---@return parser.object?
function vm.getDeprecated(value, deep)
if deep then
local defs = vm.getDefs(value)
if #defs == 0 then
- return false
+ return nil
end
- local deprecated = false
+ local deprecated
for _, def in ipairs(defs) do
if def.type == 'setglobal'
or def.type == 'setfield'
@@ -128,7 +135,7 @@ function vm.getDeprecated(value, deep)
or def.type == 'tableindex' then
deprecated = getDeprecated(def)
if not deprecated then
- return false
+ return nil
end
end
end
@@ -138,6 +145,8 @@ function vm.getDeprecated(value, deep)
end
end
+---@param value parser.object
+---@return boolean
local function isAsync(value)
if value.type == 'function' then
if not value.bindDocs then
@@ -155,9 +164,15 @@ local function isAsync(value)
value._async = false
return false
end
+ if value.type == 'main' then
+ return true
+ end
return value.async == true
end
+---@param value parser.object
+---@param deep boolean?
+---@return boolean
function vm.isAsync(value, deep)
if isAsync(value) then
return true
@@ -176,6 +191,8 @@ function vm.isAsync(value, deep)
return false
end
+---@param value parser.object
+---@return boolean
local function isNoDiscard(value)
if value.type == 'function' then
if not value.bindDocs then
@@ -196,6 +213,9 @@ local function isNoDiscard(value)
return false
end
+---@param value parser.object
+---@param deep boolean?
+---@return boolean
function vm.isNoDiscard(value, deep)
if isNoDiscard(value) then
return true
@@ -214,6 +234,8 @@ function vm.isNoDiscard(value, deep)
return false
end
+---@param param parser.object
+---@return boolean
local function isCalledInFunction(param)
if not param.ref then
return false
@@ -238,6 +260,9 @@ local function isCalledInFunction(param)
return false
end
+---@param node parser.object
+---@param index integer
+---@return boolean
local function isLinkedCall(node, index)
for _, def in ipairs(vm.getDefs(node)) do
if def.type == 'function' then
@@ -252,16 +277,21 @@ local function isLinkedCall(node, index)
return false
end
+---@param node parser.object
+---@param index integer
+---@return boolean
function vm.isLinkedCall(node, index)
return isLinkedCall(node, index)
end
+---@param call parser.object
+---@return boolean
function vm.isAsyncCall(call)
if vm.isAsync(call.node, true) then
return true
end
if not call.args then
- return
+ return false
end
for i, arg in ipairs(call.args) do
if vm.isAsync(arg, true)
@@ -272,6 +302,9 @@ function vm.isAsyncCall(call)
return false
end
+---@param uri uri
+---@param doc parser.object
+---@param results table[]
local function makeDiagRange(uri, doc, results)
local names
if doc.names then
@@ -325,7 +358,12 @@ local function makeDiagRange(uri, doc, results)
end
end
-function vm.isDiagDisabledAt(uri, position, name)
+---@param uri uri
+---@param position integer
+---@param name string
+---@param err? boolean
+---@return boolean
+function vm.isDiagDisabledAt(uri, position, name, err)
local status = files.getState(uri)
if not status then
return false
@@ -355,7 +393,8 @@ function vm.isDiagDisabledAt(uri, position, name)
local count = 0
for _, range in ipairs(cache.diagnosticRanges) do
if range.row <= myRow then
- if not range.names or range.names[name] then
+ if (range.names and range.names[name])
+ or (not range.names and not err) then
if range.mode == 'disable' then
count = count + 1
elseif range.mode == 'enable' then
diff --git a/script/vm/field.lua b/script/vm/field.lua
index 5de838be..b92c3a7b 100644
--- a/script/vm/field.lua
+++ b/script/vm/field.lua
@@ -16,7 +16,7 @@ local searchByNodeSwitch = util.switch()
end)
local function searchByLocalID(source, pushResult)
- local fields = vm.getLocalFields(source)
+ local fields = vm.getLocalFields(source, true)
if fields then
for _, field in ipairs(fields) do
pushResult(field)
diff --git a/script/vm/function.lua b/script/vm/function.lua
new file mode 100644
index 00000000..7cde6298
--- /dev/null
+++ b/script/vm/function.lua
@@ -0,0 +1,245 @@
+---@class vm
+local vm = require 'vm.vm'
+
+---@param arg parser.object
+---@return parser.object?
+local function getDocParam(arg)
+ if not arg.bindDocs then
+ return nil
+ end
+ for _, doc in ipairs(arg.bindDocs) do
+ if doc.type == 'doc.param'
+ and doc.param[1] == arg[1] then
+ return doc
+ end
+ end
+ return nil
+end
+
+---@param func parser.object
+---@return integer min
+---@return number max
+---@return integer def
+function vm.countParamsOfFunction(func)
+ local min = 0
+ local max = 0
+ local def = 0
+ if func.type == 'function' then
+ if func.args then
+ max = #func.args
+ def = max
+ for i = #func.args, 1, -1 do
+ local arg = func.args[i]
+ if arg.type == '...' then
+ max = math.huge
+ elseif arg.type == 'self'
+ and i == 1 then
+ min = i
+ break
+ elseif getDocParam(arg)
+ and not vm.compileNode(arg):isNullable() then
+ min = i
+ break
+ end
+ end
+ end
+ end
+ if func.type == 'doc.type.function' then
+ if func.args then
+ max = #func.args
+ def = max
+ for i = #func.args, 1, -1 do
+ local arg = func.args[i]
+ if arg.name and arg.name[1] =='...' then
+ max = math.huge
+ elseif not vm.compileNode(arg):isNullable() then
+ min = i
+ break
+ end
+ end
+ end
+ end
+ return min, max, def
+end
+
+---@param node vm.node
+---@return integer min
+---@return number max
+---@return integer def
+function vm.countParamsOfNode(node)
+ local min, max, def
+ for n in node:eachObject() do
+ if n.type == 'function'
+ or n.type == 'doc.type.function' then
+ ---@cast n parser.object
+ local fmin, fmax, fdef = vm.countParamsOfFunction(n)
+ if not min or fmin < min then
+ min = fmin
+ end
+ if not max or fmax > max then
+ max = fmax
+ end
+ if not def or fdef > def then
+ def = fdef
+ end
+ end
+ end
+ return min or 0, max or math.huge, def or 0
+end
+
+---@param func parser.object
+---@param onlyDoc? boolean
+---@param mark? table
+---@return integer min
+---@return number max
+---@return integer def
+function vm.countReturnsOfFunction(func, onlyDoc, mark)
+ if func.type == 'function' then
+ ---@type integer?, number?, integer?
+ local min, max, def
+ local hasDocReturn
+ if func.bindDocs then
+ local lastReturn
+ local n = 0
+ ---@type integer?, number?, integer?
+ local dmin, dmax, ddef
+ for _, doc in ipairs(func.bindDocs) do
+ if doc.type == 'doc.return' then
+ hasDocReturn = true
+ for _, ret in ipairs(doc.returns) do
+ n = n + 1
+ lastReturn = ret
+ dmax = n
+ ddef = n
+ if (not ret.name or ret.name[1] ~= '...')
+ and not vm.compileNode(ret):isNullable() then
+ dmin = n
+ end
+ end
+ end
+ end
+ if lastReturn then
+ if lastReturn.name and lastReturn.name[1] == '...' then
+ dmax = math.huge
+ end
+ end
+ if dmin and (not min or (dmin < min)) then
+ min = dmin
+ end
+ if dmax and (not max or (dmax > max)) then
+ max = dmax
+ end
+ if ddef and (not def or (ddef > def)) then
+ def = ddef
+ end
+ end
+ if not onlyDoc and not hasDocReturn and func.returns then
+ for _, ret in ipairs(func.returns) do
+ local rmin, rmax, ddef = vm.countList(ret, mark)
+ if not min or rmin < min then
+ min = rmin
+ end
+ if not max or rmax > max then
+ max = rmax
+ end
+ if not def or ddef > def then
+ def = ddef
+ end
+ end
+ end
+ return min or 0, max or math.huge, def or 0
+ end
+ if func.type == 'doc.type.function' then
+ return vm.countList(func.returns)
+ end
+ error('not a function')
+end
+
+---@param func parser.object
+---@param mark? table
+---@return integer min
+---@return number max
+---@return integer def
+function vm.countReturnsOfCall(func, args, mark)
+ local funcs = vm.getMatchedFunctions(func, args, mark)
+ ---@type integer?, number?, integer?
+ local min, max, def
+ for _, f in ipairs(funcs) do
+ local rmin, rmax, rdef = vm.countReturnsOfFunction(f, false, mark)
+ if not min or rmin < min then
+ min = rmin
+ end
+ if not max or rmax > max then
+ max = rmax
+ end
+ if not def or rdef > def then
+ def = rdef
+ end
+ end
+ return min or 0, max or math.huge, def or 0
+end
+
+---@param list parser.object[]?
+---@param mark? table
+---@return integer min
+---@return number max
+---@return integer def
+function vm.countList(list, mark)
+ if not list then
+ return 0, 0, 0
+ end
+ local lastArg = list[#list]
+ if not lastArg then
+ return 0, 0, 0
+ end
+ if lastArg.type == '...'
+ or lastArg.type == 'varargs' then
+ return #list - 1, math.huge, #list
+ end
+ if lastArg.type == 'call' then
+ if not mark then
+ mark = {}
+ end
+ if mark[lastArg] then
+ return #list - 1, math.huge, #list
+ end
+ mark[lastArg] = true
+ local rmin, rmax, rdef = vm.countReturnsOfCall(lastArg.node, lastArg.args, mark)
+ return #list - 1 + rmin, #list - 1 + rmax, #list - 1 + rdef
+ end
+ return #list, #list, #list
+end
+
+---@param func parser.object
+---@param args parser.object[]?
+---@param mark? table
+---@return parser.object[]
+function vm.getMatchedFunctions(func, args, mark)
+ local funcs = {}
+ local node = vm.compileNode(func)
+ for n in node:eachObject() do
+ if n.type == 'function'
+ or n.type == 'doc.type.function' then
+ funcs[#funcs+1] = n
+ end
+ end
+ if #funcs <= 1 then
+ return funcs
+ end
+
+ local amin, amax = vm.countList(args, mark)
+
+ local matched = {}
+ for _, n in ipairs(funcs) do
+ local min, max = vm.countParamsOfFunction(n)
+ if amin >= min and amax <= max then
+ matched[#matched+1] = n
+ end
+ end
+
+ if #matched == 0 then
+ return funcs
+ else
+ return matched
+ end
+end
diff --git a/script/vm/generic.lua b/script/vm/generic.lua
index 6462028e..544e11c9 100644
--- a/script/vm/generic.lua
+++ b/script/vm/generic.lua
@@ -11,11 +11,11 @@ local mt = {}
mt.__index = mt
mt.type = 'generic'
----@param source parser.object
+---@param source vm.object?
---@param resolved? table<string, vm.node>
----@return parser.object | vm.node
+---@return vm.object?
local function cloneObject(source, resolved)
- if not resolved then
+ if not resolved or not source then
return source
end
if source.type == 'doc.generic.name' then
@@ -121,8 +121,17 @@ function mt:resolve(uri, args)
local protoNode = vm.compileNode(self.proto)
local result = vm.createNode()
for nd in protoNode:eachObject() do
- local clonedNode = vm.compileNode(cloneObject(nd, resolved))
- result:merge(clonedNode)
+ if nd.type == 'global' then
+ ---@cast nd vm.global
+ result:merge(nd)
+ else
+ ---@cast nd -vm.global
+ local clonedObject = cloneObject(nd, resolved)
+ if clonedObject then
+ local clonedNode = vm.compileNode(clonedObject)
+ result:merge(clonedNode)
+ end
+ end
end
return result
end
diff --git a/script/vm/global.lua b/script/vm/global.lua
index a54ab552..22235681 100644
--- a/script/vm/global.lua
+++ b/script/vm/global.lua
@@ -2,6 +2,7 @@ local util = require 'utility'
local scope = require 'workspace.scope'
local guide = require 'parser.guide'
local files = require 'files'
+local ws = require 'workspace'
---@class vm
local vm = require 'vm.vm'
@@ -11,8 +12,8 @@ local vm = require 'vm.vm'
---@class vm.global
---@field links table<uri, vm.global.link>
----@field setsCache table<uri, parser.object[]>
----@field getsCache table<uri, parser.object[]>
+---@field setsCache? table<uri, parser.object[]>
+---@field getsCache? table<uri, parser.object[]>
---@field cate vm.global.cate
local mt = {}
mt.__index = mt
@@ -41,6 +42,7 @@ function mt:addGet(uri, source)
self.getsCache = nil
end
+---@param suri uri
---@return parser.object[]
function mt:getSets(suri)
if not self.setsCache then
@@ -127,7 +129,8 @@ local function createGlobal(name, cate)
end
---@class parser.object
----@field _globalNode vm.global
+---@field _globalNode vm.global|false
+---@field _enums? (string|integer)[]
---@type table<string, vm.global>
local allGlobals = {}
@@ -161,6 +164,9 @@ local compilerGlobalSwitch = util.switch()
: call(function (source)
local uri = guide.getUri(source)
local name = guide.getKeyName(source)
+ if not name then
+ return
+ end
local global = vm.declareGlobal('variable', name, uri)
global:addSet(uri, source)
source._globalNode = global
@@ -169,6 +175,9 @@ local compilerGlobalSwitch = util.switch()
: call(function (source)
local uri = guide.getUri(source)
local name = guide.getKeyName(source)
+ if not name then
+ return
+ end
local global = vm.declareGlobal('variable', name, uri)
global:addGet(uri, source)
source._globalNode = global
@@ -271,6 +280,9 @@ local compilerGlobalSwitch = util.switch()
: call(function (source)
local uri = guide.getUri(source)
local name = guide.getKeyName(source)
+ if not name then
+ return
+ end
local class = vm.declareGlobal('type', name, uri)
class:addSet(uri, source)
source._globalNode = class
@@ -293,6 +305,9 @@ local compilerGlobalSwitch = util.switch()
: call(function (source)
local uri = guide.getUri(source)
local name = guide.getKeyName(source)
+ if not name then
+ return
+ end
local alias = vm.declareGlobal('type', name, uri)
alias:addSet(uri, source)
source._globalNode = alias
@@ -305,10 +320,51 @@ local compilerGlobalSwitch = util.switch()
source.extends._generic = vm.createGeneric(source.extends, source._sign)
end
end)
+ : case 'doc.enum'
+ : call(function (source)
+ local uri = guide.getUri(source)
+ local name = guide.getKeyName(source)
+ if not name then
+ return
+ end
+ local enum = vm.declareGlobal('type', name, uri)
+ enum:addSet(uri, source)
+ source._globalNode = enum
+
+ local tbl = source.bindSource
+ if not tbl then
+ return
+ end
+ source._enums = {}
+ for _, field in ipairs(tbl) do
+ if field.type == 'tablefield'
+ or field.type == 'tableindex' then
+ if not field.value then
+ goto CONTINUE
+ end
+ local key = guide.getKeyName(field)
+ if not key then
+ goto CONTINUE
+ end
+ if field.value.type == 'integer'
+ or field.value.type == 'string' then
+ source._enums[#source._enums+1] = field.value[1]
+ end
+ if field.value.type == 'binary'
+ or field.value.type == 'unary' then
+ source._enums[#source._enums+1] = vm.getNumber(field.value)
+ end
+ ::CONTINUE::
+ end
+ end
+ end)
: case 'doc.type.name'
: call(function (source)
local uri = guide.getUri(source)
local name = source[1]
+ if name == '_' then
+ return
+ end
local type = vm.declareGlobal('type', name, uri)
type:addGet(uri, source)
source._globalNode = type
@@ -446,7 +502,7 @@ local function compileSelf(source)
if not node then
return
end
- local fields = vm.getLocalFields(source)
+ local fields = vm.getLocalFields(source, false)
if not fields then
return
end
@@ -491,6 +547,7 @@ local function compileAst(source)
'doc.alias',
'doc.type.name',
'doc.extends.name',
+ 'doc.enum',
}, function (src)
compileObject(src)
end)
@@ -530,9 +587,11 @@ for uri in files.eachFile() do
end
end
+---@async
files.watch(function (ev, uri)
if ev == 'update' then
dropUri(uri)
+ ws.awaitReady(uri)
local state = files.getState(uri)
if state then
compileAst(state.ast)
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index fabc9828..263b2500 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -8,10 +8,9 @@ local vm = require 'vm.vm'
---@field views table<string, boolean>
---@field cachedView? string
---@field node? vm.node
----@field uri? uri
+---@field _drop table
local mt = {}
mt.__index = mt
-mt._hasNumber = false
mt._hasTable = false
mt._hasClass = false
mt._hasFunctionDef = false
@@ -21,6 +20,8 @@ mt._isLocal = false
vm.NULL = setmetatable({}, mt)
+local LOCK = {}
+
local inferSorted = {
['boolean'] = - 100,
['string'] = - 99,
@@ -43,14 +44,13 @@ local viewNodeSwitch = util.switch()
end)
: case 'number'
: call(function (source, infer)
- infer._hasNumber = true
return source.type
end)
: case 'table'
- : call(function (source, infer)
+ : call(function (source, infer, uri)
if source.type == 'table' then
if #source == 1 and source[1].type == 'varargs' then
- local node = vm.getInfer(source[1]):view()
+ local node = vm.getInfer(source[1]):view(uri)
return ('%s[]'):format(node)
end
end
@@ -76,19 +76,18 @@ local viewNodeSwitch = util.switch()
: case 'global'
: call(function (source, infer)
if source.cate == 'type' then
- infer._hasClass = true
- if source.name == 'number' then
- infer._hasNumber = true
+ if not guide.isBasicType(source.name) then
+ infer._hasClass = true
end
return source.name
end
end)
: case 'doc.type.name'
- : call(function (source, infer)
+ : call(function (source, infer, uri)
if source.signs then
local buf = {}
for i, sign in ipairs(source.signs) do
- buf[i] = vm.getInfer(sign):view()
+ buf[i] = vm.getInfer(sign):view(uri)
end
return ('%s<%s>'):format(source[1], table.concat(buf, ', '))
else
@@ -96,34 +95,68 @@ local viewNodeSwitch = util.switch()
end
end)
: case 'generic'
- : call(function (source, infer)
- return vm.getInfer(source.proto):view()
+ : call(function (source, infer, uri)
+ return vm.getInfer(source.proto):view(uri)
end)
: case 'doc.generic.name'
: call(function (source, infer)
return ('<%s>'):format(source[1])
end)
: case 'doc.type.array'
- : call(function (source, infer)
+ : call(function (source, infer, uri)
infer._hasClass = true
- local view = vm.getInfer(source.node):view()
+ local view = vm.getInfer(source.node):view(uri)
if source.node.type == 'doc.type' then
view = '(' .. view .. ')'
end
return view .. '[]'
end)
: case 'doc.type.sign'
- : call(function (source, infer)
+ : call(function (source, infer, uri)
infer._hasClass = true
local buf = {}
for i, sign in ipairs(source.signs) do
- buf[i] = vm.getInfer(sign):view()
+ buf[i] = vm.getInfer(sign):view(uri)
+ end
+ if infer._drop then
+ local node = vm.compileNode(source)
+ for c in node:eachObject() do
+ if guide.isLiteral(c) then
+ infer._drop[c] = true
+ end
+ end
end
return ('%s<%s>'):format(source.node[1], table.concat(buf, ', '))
end)
: case 'doc.type.table'
- : call(function (source, infer)
- infer._hasTable = true
+ : call(function (source, infer, uri)
+ if #source.fields == 0 then
+ infer._hasTable = true
+ return
+ end
+ if infer._drop and infer._drop[source] then
+ infer._hasTable = true
+ return
+ end
+ infer._hasClass = true
+ local buf = {}
+ buf[#buf+1] = '{ '
+ for i, field in ipairs(source.fields) do
+ if i > 1 then
+ buf[#buf+1] = ', '
+ end
+ local key = field.name
+ if key.type == 'doc.type' then
+ buf[#buf+1] = ('[%s]: '):format(vm.getInfer(key):view(uri))
+ elseif type(key[1]) == 'string' then
+ buf[#buf+1] = key[1] .. ': '
+ else
+ buf[#buf+1] = ('[%q]: '):format(key[1])
+ end
+ buf[#buf+1] = vm.getInfer(field.extends):view(uri)
+ end
+ buf[#buf+1] = ' }'
+ return table.concat(buf)
end)
: case 'doc.type.string'
: call(function (source, infer)
@@ -134,8 +167,12 @@ local viewNodeSwitch = util.switch()
: call(function (source, infer)
return ('%q'):format(source[1])
end)
- : case 'doc.type.function'
+ : case 'doc.type.code'
: call(function (source, infer)
+ return ('`%s`'):format(source[1])
+ end)
+ : case 'doc.type.function'
+ : call(function (source, infer, uri)
infer._hasDocFunction = true
local args = {}
local rets = {}
@@ -148,31 +185,53 @@ local viewNodeSwitch = util.switch()
argNode = argNode:copy()
argNode:removeOptional()
end
- args[i] = string.format('%s%s: %s'
+ args[i] = string.format('%s%s%s%s'
, arg.name[1]
, isOptional and '?' or ''
- , vm.getInfer(argNode):view()
+ , arg.name[1] == '...' and '' or ': '
+ , vm.getInfer(argNode):view(uri)
)
end
if #args > 0 then
argView = table.concat(args, ', ')
end
+ local needReturnParen
for i, ret in ipairs(source.returns) do
- rets[i] = vm.getInfer(ret):view()
+ local retType = vm.getInfer(ret):view(uri)
+ if ret.name then
+ if ret.name[1] == '...' then
+ rets[i] = ('%s%s'):format(ret.name[1], retType)
+ else
+ needReturnParen = true
+ rets[i] = ('%s: %s'):format(ret.name[1], retType)
+ end
+ else
+ rets[i] = retType
+ end
end
if #rets > 0 then
- regView = ':' .. table.concat(rets, ', ')
+ if needReturnParen then
+ regView = (':(%s)'):format(table.concat(rets, ', '))
+ else
+ regView = (':%s'):format(table.concat(rets, ', '))
+ end
end
return ('fun(%s)%s'):format(argView, regView)
end)
----@param source parser.object | vm.node
+---@class vm.node
+---@field lastInfer? vm.infer
+
+---@param source vm.object | vm.node
---@return vm.infer
function vm.getInfer(source)
+ ---@type vm.node
local node
if source.type == 'vm.node' then
+ ---@cast source vm.node
node = source
else
+ ---@cast source vm.object
node = vm.compileNode(source)
end
if node.lastInfer then
@@ -180,7 +239,7 @@ function vm.getInfer(source)
end
local infer = setmetatable({
node = node,
- uri = source.type ~= 'vm.node' and guide.getUri(source),
+ _drop = {},
}, mt)
node.lastInfer = infer
@@ -188,9 +247,6 @@ function vm.getInfer(source)
end
function mt:_trim()
- if self._hasNumber then
- self.views['integer'] = nil
- end
if self._hasDocFunction then
if self._hasFunctionDef then
for view in pairs(self.views) do
@@ -205,6 +261,13 @@ function mt:_trim()
if self._hasTable and not self._hasClass then
self.views['table'] = true
end
+ if self.views['number'] then
+ self.views['integer'] = nil
+ end
+ if self.views['boolean'] then
+ self.views['true'] = nil
+ self.views['false'] = nil
+ end
end
---@param uri uri
@@ -214,46 +277,86 @@ function mt:_eraseAlias(uri)
local expandAlias = config.get(uri, 'Lua.hover.expandAlias')
for n in self.node:eachObject() do
if n.type == 'global' and n.cate == 'type' then
+ if LOCK[n.name] then
+ goto CONTINUE
+ end
+ LOCK[n.name] = true
for _, set in ipairs(n:getSets(uri)) do
if set.type == 'doc.alias' then
if expandAlias then
drop[n.name] = true
+ local newInfer = {}
+ for _, ext in ipairs(set.extends.types) do
+ viewNodeSwitch(ext.type, ext, newInfer, uri)
+ end
+ if newInfer._hasTable then
+ self.views['table'] = true
+ end
else
for _, ext in ipairs(set.extends.types) do
- local view = viewNodeSwitch(ext.type, ext, {})
+ local view = viewNodeSwitch(ext.type, ext, {}, uri)
if view and view ~= n.name then
drop[view] = true
end
end
end
end
+ if set.type == 'doc.class' then
+ if set.extends then
+ for _, ext in ipairs(set.extends) do
+ if ext.type == 'doc.extends.name' then
+ local view = ext[1]
+ drop[view] = true
+ end
+ end
+ end
+ end
end
+ LOCK[n.name] = nil
+ ::CONTINUE::
end
end
return drop
end
+---@param uri uri
---@param tp string
---@return boolean
-function mt:hasType(tp)
- self:_computeViews()
+function mt:hasType(uri, tp)
+ self:_computeViews(uri)
return self.views[tp] == true
end
+---@param uri uri
+function mt:hasUnknown(uri)
+ self:_computeViews(uri)
+ return not next(self.views)
+ or self.views['unknown'] == true
+end
+
+---@param uri uri
+function mt:hasAny(uri)
+ self:_computeViews(uri)
+ return self.views['any'] == true
+end
+
+---@param uri uri
---@return boolean
-function mt:hasClass()
- self:_computeViews()
+function mt:hasClass(uri)
+ self:_computeViews(uri)
return self._hasClass == true
end
+---@param uri uri
---@return boolean
-function mt:hasFunction()
- self:_computeViews()
+function mt:hasFunction(uri)
+ self:_computeViews(uri)
return self.views['function'] == true
or self._hasDocFunction == true
end
-function mt:_computeViews()
+---@param uri uri
+function mt:_computeViews(uri)
if self.views then
return
end
@@ -261,7 +364,7 @@ function mt:_computeViews()
self.views = {}
for n in self.node:eachObject() do
- local view = viewNodeSwitch(n.type, n, self)
+ local view = viewNodeSwitch(n.type, n, self, uri)
if view then
self.views[view] = true
end
@@ -270,11 +373,11 @@ function mt:_computeViews()
self:_trim()
end
+---@param uri uri
---@param default? string
----@param uri? uri
---@return string
-function mt:view(default, uri)
- self:_computeViews()
+function mt:view(uri, default)
+ self:_computeViews(uri)
if self.views['any'] then
return 'any'
@@ -282,7 +385,7 @@ function mt:view(default, uri)
local drop
if self._hasClass then
- drop = self:_eraseAlias(uri or self.uri)
+ drop = self:_eraseAlias(uri)
end
local array = {}
@@ -302,7 +405,7 @@ function mt:view(default, uri)
end)
local max = #array
- local limit = config.get(uri or self.uri, 'Lua.hover.enumsLimit')
+ local limit = config.get(uri, 'Lua.hover.enumsLimit')
local view
if #array == 0 then
@@ -329,8 +432,9 @@ function mt:view(default, uri)
return view
end
-function mt:eachView()
- self:_computeViews()
+---@param uri uri
+function mt:eachView(uri)
+ self:_computeViews(uri)
return next, self.views
end
@@ -346,7 +450,6 @@ function mt:merge(other)
local infer = setmetatable({
node = vm.createNode(self.node, other.node),
- uri = self.uri,
}, mt)
return infer
@@ -365,7 +468,7 @@ function mt:viewLiterals()
or n.type == 'integer'
or n.type == 'boolean' then
local literal = util.viewLiteral(n[1])
- if not mark[literal] then
+ if literal and not mark[literal] then
literals[#literals+1] = literal
mark[literal] = true
end
@@ -374,7 +477,14 @@ function mt:viewLiterals()
if #literals == 0 then
return nil
end
- table.sort(literals)
+ table.sort(literals, function (a, b)
+ local sa = inferSorted[a] or 0
+ local sb = inferSorted[b] or 0
+ if sa == sb then
+ return a < b
+ end
+ return sa < sb
+ end)
return table.concat(literals, '|')
end
@@ -401,8 +511,9 @@ function mt:viewClass()
return table.concat(class, '|')
end
----@param source parser.object
+---@param source vm.node.object
+---@param uri uri
---@return string?
-function vm.viewObject(source)
- return viewNodeSwitch(source.type, source, {})
+function vm.viewObject(source, uri)
+ return viewNodeSwitch(source.type, source, {}, uri)
end
diff --git a/script/vm/init.lua b/script/vm/init.lua
index f5003c11..87c046f0 100644
--- a/script/vm/init.lua
+++ b/script/vm/init.lua
@@ -1,6 +1,6 @@
local vm = require 'vm.vm'
----@alias vm.object parser.object | vm.global | vm.generic
+---@alias vm.object parser.object | vm.generic
require 'vm.compiler'
require 'vm.value'
@@ -17,4 +17,6 @@ require 'vm.generic'
require 'vm.sign'
require 'vm.local-id'
require 'vm.global'
+require 'vm.function'
+require 'vm.operator'
return vm
diff --git a/script/vm/local-id.lua b/script/vm/local-id.lua
index 80c68769..9168d680 100644
--- a/script/vm/local-id.lua
+++ b/script/vm/local-id.lua
@@ -4,8 +4,8 @@ local guide = require 'parser.guide'
local vm = require 'vm.vm'
---@class parser.object
----@field _localID string
----@field _localIDs table<string, parser.object[]>
+---@field _localID string|false
+---@field _localIDs table<string, { sets: parser.object[], gets: parser.object[] }>
local compileLocalID, getLocal
@@ -21,6 +21,7 @@ local compileSwitch = util.switch()
compileLocalID(ref)
end
end)
+ : case 'setlocal'
: case 'getlocal'
: call(function (source)
source._localID = ('%d'):format(source.node.start)
@@ -114,10 +115,19 @@ end
function vm.insertLocalID(id, source)
local root = guide.getRoot(source)
if not root._localIDs then
- root._localIDs = util.multiTable(2)
+ root._localIDs = util.multiTable(2, function ()
+ return {
+ sets = {},
+ gets = {},
+ }
+ end)
end
local sources = root._localIDs[id]
- sources[#sources+1] = source
+ if guide.isSet(source) then
+ sources.sets[#sources.sets+1] = source
+ else
+ sources.gets[#sources.gets+1] = source
+ end
end
function compileLocalID(source)
@@ -137,7 +147,7 @@ function compileLocalID(source)
end
---@param source parser.object
----@return string?
+---@return string|false
function vm.getLocalID(source)
if source._localID ~= nil then
return source._localID
@@ -154,7 +164,28 @@ end
---@param source parser.object
---@param key? string
---@return parser.object[]?
-function vm.getLocalSources(source, key)
+function vm.getLocalSourcesSets(source, key)
+ local id = vm.getLocalID(source)
+ if not id then
+ return nil
+ end
+ local root = guide.getRoot(source)
+ if not root._localIDs then
+ return nil
+ end
+ if key then
+ if type(key) ~= 'string' then
+ return nil
+ end
+ id = id .. vm.ID_SPLITE .. key
+ end
+ return root._localIDs[id].sets
+end
+
+---@param source parser.object
+---@param key? string
+---@return parser.object[]?
+function vm.getLocalSourcesGets(source, key)
local id = vm.getLocalID(source)
if not id then
return nil
@@ -169,12 +200,13 @@ function vm.getLocalSources(source, key)
end
id = id .. vm.ID_SPLITE .. key
end
- return root._localIDs[id]
+ return root._localIDs[id].gets
end
---@param source parser.object
----@return parser.object[]
-function vm.getLocalFields(source)
+---@param includeGets boolean
+---@return parser.object[]?
+function vm.getLocalFields(source, includeGets)
local id = vm.getLocalID(source)
if not id then
return nil
@@ -192,9 +224,14 @@ function vm.getLocalFields(source)
and lid:sub(#id + 1, #id + 1) == vm.ID_SPLITE
-- only one field
and not lid:find(vm.ID_SPLITE, #id + 2) then
- for _, src in ipairs(sources) do
+ for _, src in ipairs(sources.sets) do
fields[#fields+1] = src
end
+ if includeGets then
+ for _, src in ipairs(sources.gets) do
+ fields[#fields+1] = src
+ end
+ end
end
end
local cost = os.clock() - clock
diff --git a/script/vm/node.lua b/script/vm/node.lua
index e76542aa..49207b13 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -2,28 +2,32 @@ local files = require 'files'
---@class vm
local vm = require 'vm.vm'
local ws = require 'workspace.workspace'
+local guide = require 'parser.guide'
---@type table<vm.object, vm.node>
vm.nodeCache = {}
+---@alias vm.node.object vm.object | vm.global
+
---@class vm.node
----@field [integer] vm.object
+---@field [integer] vm.node.object
+---@field [vm.node.object] true
local mt = {}
mt.__index = mt
mt.id = 0
mt.type = 'vm.node'
mt.optional = nil
-mt.lastInfer = nil
mt.data = nil
----@param node vm.node | vm.object
+---@param node vm.node | vm.node.object
+---@return vm.node
function mt:merge(node)
if not node then
- return
+ return self
end
if node.type == 'vm.node' then
if node == self then
- return
+ return self
end
if node:isOptional() then
self.optional = true
@@ -35,11 +39,13 @@ function mt:merge(node)
end
end
else
+ ---@cast node -vm.node
if not self[node] then
self[node] = true
self[#self+1] = node
end
end
+ return self
end
---@return boolean
@@ -56,7 +62,7 @@ function mt:clear()
end
---@param n integer
----@return vm.object?
+---@return vm.node.object?
function mt:get(n)
return self[n]
end
@@ -68,6 +74,7 @@ function mt:setData(k, v)
self.data[k] = v
end
+---@return any
function mt:getData(k)
if not self.data then
return nil
@@ -81,6 +88,7 @@ end
function mt:removeOptional()
self:remove 'nil'
+ return self
end
---@return boolean
@@ -106,6 +114,19 @@ function mt:hasFalsy()
end
---@return boolean
+function mt:hasKnownType()
+ for _, c in ipairs(self) do
+ if c.type == 'global' and c.cate == 'type' then
+ return true
+ end
+ if guide.isLiteral(c) then
+ return true
+ end
+ end
+ return false
+end
+
+---@return boolean
function mt:isNullable()
if self.optional then
return true
@@ -116,7 +137,8 @@ function mt:isNullable()
for _, c in ipairs(self) do
if c.type == 'nil'
or (c.type == 'global' and c.cate == 'type' and c.name == 'nil')
- or (c.type == 'global' and c.cate == 'type' and c.name == 'any') then
+ or (c.type == 'global' and c.cate == 'type' and c.name == 'any')
+ or (c.type == 'global' and c.cate == 'type' and c.name == '...') then
return true
end
end
@@ -140,18 +162,25 @@ function mt:setTruthy()
self[c] = nil
goto CONTINUE
end
- if (c.type == 'global' and c.cate == 'type' and c.name == 'boolean')
- or (c.type == 'boolean' or c.type == 'doc.type.boolean') then
+ if c.type == 'global' and c.cate == 'type' and c.name == 'boolean' then
hasBoolean = true
table.remove(self, index)
self[c] = nil
goto CONTINUE
end
+ if c.type == 'boolean' or c.type == 'doc.type.boolean' then
+ if c[1] == false then
+ table.remove(self, index)
+ self[c] = nil
+ goto CONTINUE
+ end
+ end
::CONTINUE::
end
if hasBoolean then
- self[#self+1] = vm.declareGlobal('type', 'true')
+ self:merge(vm.declareGlobal('type', 'true'))
end
+ return self
end
---@return vm.node
@@ -165,21 +194,39 @@ function mt:setFalsy()
if c.type == 'nil'
or (c.type == 'global' and c.cate == 'type' and c.name == 'nil')
or (c.type == 'global' and c.cate == 'type' and c.name == 'false')
- or (c.type == 'boolean' and c[1] == true)
- or (c.type == 'doc.type.boolean' and c[1] == true) then
+ or (c.type == 'boolean' and c[1] == false)
+ or (c.type == 'doc.type.boolean' and c[1] == false) then
goto CONTINUE
end
- if (c.type == 'global' and c.cate == 'type' and c.name == 'boolean')
- or (c.type == 'boolean' or c.type == 'doc.type.boolean') then
+ if c.type == 'global' and c.cate == 'type' and c.name == 'boolean' then
hasBoolean = true
table.remove(self, index)
self[c] = nil
+ goto CONTINUE
+ end
+ if c.type == 'boolean' or c.type == 'doc.type.boolean' then
+ if c[1] == true then
+ table.remove(self, index)
+ self[c] = nil
+ goto CONTINUE
+ end
+ end
+ if (c.type == 'global' and c.cate == 'type') then
+ table.remove(self, index)
+ self[c] = nil
+ goto CONTINUE
+ end
+ if guide.isLiteral(c) then
+ table.remove(self, index)
+ self[c] = nil
+ goto CONTINUE
end
::CONTINUE::
end
if hasBoolean then
- self[#self+1] = vm.declareGlobal('type', 'false')
+ self:merge(vm.declareGlobal('type', 'false'))
end
+ return self
end
---@param name string
@@ -193,25 +240,144 @@ function mt:remove(name)
or (c.type == name)
or (c.type == 'doc.type.integer' and (name == 'number' or name == 'integer'))
or (c.type == 'doc.type.boolean' and name == 'boolean')
+ or (c.type == 'doc.type.boolean' and name == 'true' and c[1] == true)
+ or (c.type == 'doc.type.boolean' and name == 'false' and c[1] == false)
or (c.type == 'doc.type.table' and name == 'table')
or (c.type == 'doc.type.array' and name == 'table')
+ or (c.type == 'doc.type.sign' and name == c.node[1])
or (c.type == 'doc.type.function' and name == 'function') then
table.remove(self, index)
self[c] = nil
end
end
+ return self
+end
+
+---@param name string
+function mt:narrow(name)
+ if name ~= 'nil' and self.optional == true then
+ self.optional = nil
+ end
+ for index = #self, 1, -1 do
+ local c = self[index]
+ if (c.type == name)
+ or (c.type == 'doc.type.integer' and (name == 'number' or name == 'integer'))
+ or (c.type == 'doc.type.boolean' and name == 'boolean')
+ or (c.type == 'doc.type.table' and name == 'table')
+ or (c.type == 'doc.type.array' and name == 'table')
+ or (c.type == 'doc.type.sign' and name == c.node[1])
+ or (c.type == 'doc.type.function' and name == 'function') then
+ goto CONTINUE
+ end
+ if c.type == 'global' and c.cate == 'type' then
+ if (c.name == name)
+ or (c.name == 'integer' and name == 'number') then
+ goto CONTINUE
+ end
+ end
+ table.remove(self, index)
+ self[c] = nil
+ ::CONTINUE::
+ end
+ if #self == 0 then
+ self[#self+1] = vm.getGlobal('type', name)
+ end
+ return self
+end
+
+---@param obj vm.object
+function mt:removeObject(obj)
+ for index, c in ipairs(self) do
+ if c == obj then
+ table.remove(self, index)
+ self[c] = nil
+ return
+ end
+ end
end
---@param node vm.node
function mt:removeNode(node)
for _, c in ipairs(node) do
if c.type == 'global' and c.cate == 'type' then
+ ---@cast c vm.global
self:remove(c.name)
+ elseif c.type == 'nil' then
+ self:remove 'nil'
+ elseif c.type == 'boolean'
+ or c.type == 'doc.type.boolean' then
+ if c[1] == true then
+ self:remove 'true'
+ else
+ self:remove 'false'
+ end
+ else
+ ---@cast c -vm.global
+ self:removeObject(c)
end
end
end
----@return fun():vm.object
+---@param name string
+---@return boolean
+function mt:hasType(name)
+ for _, c in ipairs(self) do
+ if c.type == 'global' and c.cate == 'type' and c.name == name then
+ return true
+ end
+ end
+ return false
+end
+
+---@param name string
+---@return boolean
+function mt:hasName(name)
+ if name == 'nil' and self.optional == true then
+ return true
+ end
+ for _, c in ipairs(self) do
+ if c.type == 'global' and c.cate == 'type' and c.name == name then
+ return true
+ end
+ if c.type == name then
+ return true
+ end
+ -- TODO
+ end
+ return false
+end
+
+---@return vm.node
+function mt:asTable()
+ self.optional = nil
+ for index = #self, 1, -1 do
+ local c = self[index]
+ if c.type == 'table'
+ or c.type == 'doc.type.table'
+ or c.type == 'doc.type.array' then
+ goto CONTINUE
+ end
+ if c.type == 'doc.type.sign' then
+ if c.node[1] == 'table'
+ or not guide.isBasicType(c.node[1]) then
+ goto CONTINUE
+ end
+ end
+ if c.type == 'global' and c.cate == 'type' then
+ ---@cast c vm.global
+ if c.name == 'table'
+ or not guide.isBasicType(c.name) then
+ goto CONTINUE
+ end
+ end
+ table.remove(self, index)
+ self[c] = nil
+ ::CONTINUE::
+ end
+ return self
+end
+
+---@return fun():vm.node.object
function mt:eachObject()
local i = 0
return function ()
@@ -226,8 +392,9 @@ function mt:copy()
end
---@param source vm.object
----@param node vm.node | vm.object
+---@param node vm.node | vm.node.object
---@param cover? boolean
+---@return vm.node
function vm.setNode(source, node, cover)
if not node then
if TEST then
@@ -236,23 +403,23 @@ function vm.setNode(source, node, cover)
log.error('Can not set nil node')
end
end
- if source.type == 'global' then
- error('Can not set node to global')
- end
if cover then
+ ---@cast node vm.node
vm.nodeCache[source] = node
- return
+ return node
end
local me = vm.nodeCache[source]
if me then
me:merge(node)
else
if node.type == 'vm.node' then
- vm.nodeCache[source] = node:copy()
+ me = node:copy()
else
- vm.nodeCache[source] = vm.createNode(node)
+ me = vm.createNode(node)
end
+ vm.nodeCache[source] = me
end
+ return me
end
---@param source vm.object
@@ -291,8 +458,8 @@ end
local ID = 0
----@param a? vm.node | vm.object
----@param b? vm.node | vm.object
+---@param a? vm.node | vm.node.object
+---@param b? vm.node | vm.node.object
---@return vm.node
function vm.createNode(a, b)
ID = ID + 1
@@ -315,3 +482,9 @@ files.watch(function (ev, uri)
end
end
end)
+
+ws.watch(function (ev, uri)
+ if ev == 'reload' then
+ vm.clearNodeCache()
+ end
+end)
diff --git a/script/vm/operator.lua b/script/vm/operator.lua
new file mode 100644
index 00000000..9dea01c1
--- /dev/null
+++ b/script/vm/operator.lua
@@ -0,0 +1,368 @@
+---@class vm
+local vm = require 'vm.vm'
+local util = require 'utility'
+local guide = require 'parser.guide'
+local config = require 'config'
+
+vm.UNARY_OP = {
+ 'unm',
+ 'bnot',
+ 'len',
+}
+vm.BINARY_OP = {
+ 'add',
+ 'sub',
+ 'mul',
+ 'div',
+ 'mod',
+ 'pow',
+ 'idiv',
+ 'band',
+ 'bor',
+ 'bxor',
+ 'shl',
+ 'shr',
+ 'concat',
+}
+vm.OTHER_OP = {
+ 'call',
+}
+
+local unaryMap = {
+ ['-'] = 'unm',
+ ['~'] = 'bnot',
+ ['#'] = 'len',
+}
+
+local binaryMap = {
+ ['+'] = 'add',
+ ['-'] = 'sub',
+ ['*'] = 'mul',
+ ['/'] = 'div',
+ ['%'] = 'mod',
+ ['^'] = 'pow',
+ ['//'] = 'idiv',
+ ['&'] = 'band',
+ ['|'] = 'bor',
+ ['~'] = 'bxor',
+ ['<<'] = 'shl',
+ ['>>'] = 'shr',
+ ['..'] = 'concat',
+}
+
+local otherMap = {
+ ['()'] = 'call',
+}
+
+vm.OP_UNARY_MAP = util.revertMap(unaryMap)
+vm.OP_BINARY_MAP = util.revertMap(binaryMap)
+vm.OP_OTHER_MAP = util.revertMap(otherMap)
+
+---@param operators parser.object[]
+---@param op string
+---@param value? parser.object
+---@param result? vm.node
+---@return vm.node?
+local function checkOperators(operators, op, value, result)
+ for _, operator in ipairs(operators) do
+ if operator.op[1] ~= op
+ or not operator.extends then
+ goto CONTINUE
+ end
+ if value and operator.exp then
+ local valueNode = vm.compileNode(value)
+ local expNode = vm.compileNode(operator.exp)
+ local uri = guide.getUri(operator)
+ if not vm.isSubType(uri, valueNode, expNode) then
+ goto CONTINUE
+ end
+ end
+ if not result then
+ result = vm.createNode()
+ end
+ result:merge(vm.compileNode(operator.extends))
+ ::CONTINUE::
+ end
+ return result
+end
+
+---@param op string
+---@param exp parser.object
+---@param value? parser.object
+---@return vm.node?
+function vm.runOperator(op, exp, value)
+ local uri = guide.getUri(exp)
+ local node = vm.compileNode(exp)
+ local result
+ for c in node:eachObject() do
+ if c.type == 'string'
+ or c.type == 'doc.type.string' then
+ c = vm.declareGlobal('type', 'string')
+ end
+ if c.type == 'global' and c.cate == 'type' then
+ ---@cast c vm.global
+ for _, set in ipairs(c:getSets(uri)) do
+ if set.operators and #set.operators > 0 then
+ result = checkOperators(set.operators, op, value, result)
+ end
+ end
+ end
+ end
+ return result
+end
+
+vm.unarySwich = util.switch()
+ : case 'not'
+ : call(function (source)
+ local result = vm.testCondition(source[1])
+ if result == nil then
+ vm.setNode(source, vm.declareGlobal('type', 'boolean'))
+ else
+ vm.setNode(source, {
+ type = 'boolean',
+ start = source.start,
+ finish = source.finish,
+ parent = source,
+ [1] = not result,
+ })
+ end
+ end)
+ : case '#'
+ : call(function (source)
+ local node = vm.runOperator('len', source[1])
+ vm.setNode(source, node or vm.declareGlobal('type', 'integer'))
+ end)
+ : case '-'
+ : call(function (source)
+ local v = vm.getNumber(source[1])
+ if v == nil then
+ local uri = guide.getUri(source)
+ local infer = vm.getInfer(source[1])
+ if infer:hasType(uri, 'integer') then
+ vm.setNode(source, vm.declareGlobal('type', 'integer'))
+ elseif infer:hasType(uri, 'number') then
+ vm.setNode(source, vm.declareGlobal('type', 'number'))
+ else
+ local node = vm.runOperator('unm', source[1])
+ vm.setNode(source, node or vm.declareGlobal('type', 'number'))
+ end
+ else
+ vm.setNode(source, {
+ type = 'number',
+ start = source.start,
+ finish = source.finish,
+ parent = source,
+ [1] = -v,
+ })
+ end
+ end)
+ : case '~'
+ : call(function (source)
+ local v = vm.getInteger(source[1])
+ if v == nil then
+ local node = vm.runOperator('bnot', source[1])
+ vm.setNode(source, node or vm.declareGlobal('type', 'integer'))
+ else
+ vm.setNode(source, {
+ type = 'integer',
+ start = source.start,
+ finish = source.finish,
+ parent = source,
+ [1] = ~v,
+ })
+ end
+ end)
+
+vm.binarySwitch = util.switch()
+ : case 'and'
+ : call(function (source)
+ local node1 = vm.compileNode(source[1])
+ local node2 = vm.compileNode(source[2])
+ local r1 = vm.testCondition(source[1])
+ if r1 == true then
+ vm.setNode(source, node2)
+ elseif r1 == false then
+ vm.setNode(source, node1)
+ else
+ local node = node1:copy():setFalsy():merge(node2)
+ vm.setNode(source, node)
+ end
+ end)
+ : case 'or'
+ : call(function (source)
+ local node1 = vm.compileNode(source[1])
+ local node2 = vm.compileNode(source[2])
+ local r1 = vm.testCondition(source[1])
+ if r1 == true then
+ vm.setNode(source, node1)
+ elseif r1 == false then
+ vm.setNode(source, node2)
+ else
+ local node = node1:copy():setTruthy():merge(node2)
+ vm.setNode(source, node)
+ end
+ end)
+ : case '=='
+ : case '~='
+ : call(function (source)
+ local result = vm.equal(source[1], source[2])
+ if result == nil then
+ vm.setNode(source, vm.declareGlobal('type', 'boolean'))
+ else
+ if source.op.type == '~=' then
+ result = not result
+ end
+ vm.setNode(source, {
+ type = 'boolean',
+ start = source.start,
+ finish = source.finish,
+ parent = source,
+ [1] = result,
+ })
+ end
+ end)
+ : case '<<'
+ : case '>>'
+ : case '&'
+ : case '|'
+ : case '~'
+ : call(function (source)
+ local a = vm.getInteger(source[1])
+ local b = vm.getInteger(source[2])
+ local op = source.op.type
+ if a and b then
+ local result = op == '<<' and a << b
+ or op == '>>' and a >> b
+ or op == '&' and a & b
+ or op == '|' and a | b
+ or op == '~' and a ~ b
+ vm.setNode(source, {
+ type = 'integer',
+ start = source.start,
+ finish = source.finish,
+ parent = source,
+ [1] = result,
+ })
+ else
+ local node = vm.runOperator(binaryMap[op], source[1], source[2])
+ vm.setNode(source, node or vm.declareGlobal('type', 'integer'))
+ end
+ end)
+ : case '+'
+ : case '-'
+ : case '*'
+ : case '/'
+ : case '%'
+ : case '//'
+ : case '^'
+ : call(function (source)
+ local a = vm.getNumber(source[1])
+ local b = vm.getNumber(source[2])
+ local op = source.op.type
+ local zero = b == 0
+ and ( op == '%'
+ or op == '/'
+ or op == '//'
+ )
+ if a and b and not zero then
+ local result = op == '+' and a + b
+ or op == '-' and a - b
+ or op == '*' and a * b
+ or op == '/' and a / b
+ or op == '%' and a % b
+ or op == '//' and a // b
+ or op == '^' and a ^ b
+ vm.setNode(source, {
+ type = math.type(result) == 'integer' and 'integer' or 'number',
+ start = source.start,
+ finish = source.finish,
+ parent = source,
+ [1] = result,
+ })
+ else
+ local node = vm.runOperator(binaryMap[op], source[1], source[2])
+ if node then
+ vm.setNode(source, node)
+ return
+ end
+ if op == '+'
+ or op == '-'
+ or op == '*'
+ or op == '//'
+ or op == '%' then
+ local uri = guide.getUri(source)
+ local infer1 = vm.getInfer(source[1])
+ local infer2 = vm.getInfer(source[2])
+ if infer1:hasType(uri, 'integer')
+ or infer2:hasType(uri, 'integer') then
+ if not infer1:hasType(uri, 'number')
+ and not infer2:hasType(uri, 'number') then
+ vm.setNode(source, vm.declareGlobal('type', 'integer'))
+ return
+ end
+ end
+ end
+ vm.setNode(source, node or vm.declareGlobal('type', 'number'))
+ end
+ end)
+ : case '..'
+ : call(function (source)
+ local a = vm.getString(source[1])
+ or vm.getNumber(source[1])
+ local b = vm.getString(source[2])
+ or vm.getNumber(source[2])
+ if a and b then
+ if type(a) == 'number' or type(b) == 'number' then
+ local uri = guide.getUri(source)
+ local version = config.get(uri, 'Lua.runtime.version')
+ if math.tointeger(a) and math.type(a) == 'float' then
+ if version == 'Lua 5.3' or version == 'Lua 5.4' then
+ a = ('%.1f'):format(a)
+ else
+ a = ('%.0f'):format(a)
+ end
+ end
+ if math.tointeger(b) and math.type(b) == 'float' then
+ if version == 'Lua 5.3' or version == 'Lua 5.4' then
+ b = ('%.1f'):format(b)
+ else
+ b = ('%.0f'):format(b)
+ end
+ end
+ end
+ vm.setNode(source, {
+ type = 'string',
+ start = source.start,
+ finish = source.finish,
+ parent = source,
+ [1] = a .. b,
+ })
+ else
+ local node = vm.runOperator(binaryMap[source.op.type], source[1], source[2])
+ vm.setNode(source, node or vm.declareGlobal('type', 'string'))
+ end
+ end)
+ : case '>'
+ : case '<'
+ : case '>='
+ : case '<='
+ : call(function (source)
+ local a = vm.getNumber(source[1])
+ local b = vm.getNumber(source[2])
+ if a and b then
+ local op = source.op.type
+ local result = op == '>' and a > b
+ or op == '<' and a < b
+ or op == '>=' and a >= b
+ or op == '<=' and a <= b
+ vm.setNode(source, {
+ type = 'boolean',
+ start = source.start,
+ finish = source.finish,
+ parent = source,
+ [1] =result,
+ })
+ else
+ vm.setNode(source, vm.declareGlobal('type', 'boolean'))
+ end
+ end)
diff --git a/script/vm/ref.lua b/script/vm/ref.lua
index 545c294a..0135d11f 100644
--- a/script/vm/ref.lua
+++ b/script/vm/ref.lua
@@ -9,58 +9,7 @@ local lang = require 'language'
local simpleSwitch
-local function searchGetLocal(source, node, pushResult)
- local key = guide.getKeyName(source)
- for _, ref in ipairs(node.node.ref) do
- if ref.type == 'getlocal'
- and ref.next
- and guide.getKeyName(ref.next) == key then
- pushResult(ref.next)
- end
- end
-end
-
simpleSwitch = util.switch()
- : case 'local'
- : call(function (source, pushResult)
- if source.ref then
- for _, ref in ipairs(source.ref) do
- if ref.type == 'setlocal'
- or ref.type == 'getlocal' then
- pushResult(ref)
- end
- end
- end
- end)
- : case 'getlocal'
- : case 'setlocal'
- : call(function (source, pushResult)
- simpleSwitch('local', source.node, pushResult)
- end)
- : case 'field'
- : call(function (source, pushResult)
- local parent = source.parent
- if parent.type ~= 'tablefield' then
- simpleSwitch(parent.type, parent, pushResult)
- end
- end)
- : case 'setfield'
- : case 'getfield'
- : call(function (source, pushResult)
- local node = source.node
- if node.type == 'getlocal' then
- searchGetLocal(source, node, pushResult)
- return
- end
- end)
- : case 'getindex'
- : case 'setindex'
- : call(function (source, pushResult)
- local node = source.node
- if node.type == 'getlocal' then
- searchGetLocal(source, node, pushResult)
- end
- end)
: case 'goto'
: call(function (source, pushResult)
if source.node then
@@ -142,21 +91,21 @@ local function searchField(source, pushResult, defMap, fileNotify)
return
end
---@async
- guide.eachSourceType(state.ast, 'getfield', function (src)
+ guide.eachSourceTypes(state.ast, {'getfield', 'setfield'}, function (src)
if src.field and src.field[1] == key then
checkDef(src)
await.delay()
end
end)
---@async
- guide.eachSourceType(state.ast, 'getmethod', function (src)
+ guide.eachSourceTypes(state.ast, {'getmethod', 'setmethod'}, function (src)
if src.method and src.method[1] == key then
checkDef(src)
await.delay()
end
end)
---@async
- guide.eachSourceType(state.ast, 'getindex', function (src)
+ guide.eachSourceTypes(state.ast, {'getindex', 'setindex'}, function (src)
if src.index and src.index.type == 'string' and src.index[1] == key then
checkDef(src)
await.delay()
@@ -240,19 +189,24 @@ end
---@param source parser.object
---@param pushResult fun(src: parser.object)
local function searchByLocalID(source, pushResult)
- local idSources = vm.getLocalSources(source)
- if not idSources then
- return
+ local sourceSets = vm.getLocalSourcesSets(source)
+ if sourceSets then
+ for _, src in ipairs(sourceSets) do
+ pushResult(src)
+ end
end
- for _, src in ipairs(idSources) do
- pushResult(src)
+ local sourceGets = vm.getLocalSourcesGets(source)
+ if sourceGets then
+ for _, src in ipairs(sourceGets) do
+ pushResult(src)
+ end
end
end
---@async
---@param source parser.object
---@param pushResult fun(src: parser.object)
----@param fileNotify fun(uri: uri): boolean
+---@param fileNotify? fun(uri: uri): boolean
function searchByParentNode(source, pushResult, defMap, fileNotify)
nodeSwitch(source.type, source, pushResult, defMap, fileNotify)
end
@@ -279,10 +233,22 @@ local function searchByDef(source, pushResult)
defMap[source] = true
return defMap
end
- local defs = vm.getDefs(source)
- for _, def in ipairs(defs) do
- pushResult(def)
- defMap[def] = true
+ if source.type == 'field'
+ or source.type == 'method' then
+ source = source.parent
+ end
+ defMap[source] = true
+ if guide.isSet(source) then
+ local defs = vm.getDefs(source)
+ for _, def in ipairs(defs) do
+ pushResult(def)
+ end
+ else
+ local defs = vm.getDefs(source)
+ for _, def in ipairs(defs) do
+ pushResult(def)
+ defMap[def] = true
+ end
end
return defMap
end
diff --git a/script/vm/runner.lua b/script/vm/runner.lua
index 9fe0f172..2f047983 100644
--- a/script/vm/runner.lua
+++ b/script/vm/runner.lua
@@ -2,257 +2,22 @@
local vm = require 'vm.vm'
local guide = require 'parser.guide'
+---@alias vm.runner.callback fun(src: parser.object, node?: vm.node)
+
---@class vm.runner
----@field loc parser.object
----@field mainBlock parser.object
----@field blocks table<parser.object, true>
----@field steps vm.runner.step[]
+---@field _loc parser.object
+---@field _casts parser.object[]
+---@field _callback vm.runner.callback
+---@field _mark table
+---@field _has table<parser.object, true>
+---@field _main parser.object
local mt = {}
mt.__index = mt
-mt.index = 1
-
----@class parser.object
----@field _casts parser.object[]
-
----@class vm.runner.step
----@field type 'truthy' | 'falsy' | 'as' | 'add' | 'remove' | 'object' | 'save' | 'push' | 'merge' | 'cast'
----@field pos integer
----@field order? integer
----@field node? vm.node
----@field object? parser.object
----@field name? string
----@field cast? parser.object
----@field tag? string
----@field copy? boolean
----@field new? boolean
----@field ref1? vm.runner.step
----@field ref2? vm.runner.step
-
----@param filter parser.object
----@param outStep vm.runner.step
----@param blockStep vm.runner.step
-function mt:_compileNarrowByFilter(filter, outStep, blockStep)
- if not filter then
- return
- end
- if filter.type == 'paren' then
- if filter.exp then
- self:_compileNarrowByFilter(filter.exp, outStep, blockStep)
- end
- return
- end
- if filter.type == 'unary' then
- if not filter.op
- or not filter[1] then
- return
- end
- if filter.op.type == 'not' then
- local exp = filter[1]
- if exp.type == 'getlocal' and exp.node == self.loc then
- self.steps[#self.steps+1] = {
- type = 'falsy',
- pos = filter.finish,
- new = true,
- }
- self.steps[#self.steps+1] = {
- type = 'truthy',
- pos = filter.finish,
- ref1 = outStep,
- }
- end
- end
- elseif filter.type == 'binary' then
- if not filter.op
- or not filter[1]
- or not filter[2] then
- return
- end
- if filter.op.type == 'and' then
- local dummyStep = {
- type = 'save',
- copy = true,
- ref1 = outStep,
- pos = filter.start - 1,
- }
- self.steps[#self.steps+1] = dummyStep
- self:_compileNarrowByFilter(filter[1], dummyStep, blockStep)
- self:_compileNarrowByFilter(filter[2], dummyStep, blockStep)
- end
- if filter.op.type == 'or' then
- self:_compileNarrowByFilter(filter[1], outStep, blockStep)
- local dummyStep = {
- type = 'push',
- copy = true,
- ref1 = outStep,
- pos = filter.op.finish,
- }
- self.steps[#self.steps+1] = dummyStep
- self:_compileNarrowByFilter(filter[2], outStep, dummyStep)
- self.steps[#self.steps+1] = {
- type = 'push',
- tag = 'or reset',
- ref1 = blockStep,
- pos = filter.finish,
- }
- end
- if filter.op.type == '=='
- or filter.op.type == '~=' then
- local loc, exp
- for i = 1, 2 do
- loc = filter[i]
- if loc.type == 'getlocal' and loc.node == self.loc then
- exp = filter[i % 2 + 1]
- break
- end
- end
- if not loc or not exp then
- return
- end
- if guide.isLiteral(exp) then
- if filter.op.type == '==' then
- self.steps[#self.steps+1] = {
- type = 'remove',
- name = exp.type,
- pos = filter.finish,
- ref1 = outStep,
- }
- self.steps[#self.steps+1] = {
- type = 'as',
- name = exp.type,
- pos = filter.finish,
- new = true,
- }
- end
- if filter.op.type == '~=' then
- self.steps[#self.steps+1] = {
- type = 'as',
- name = exp.type,
- pos = filter.finish,
- ref1 = outStep,
- }
- self.steps[#self.steps+1] = {
- type = 'remove',
- name = exp.type,
- pos = filter.finish,
- new = true,
- }
- end
- end
- end
- else
- if filter.type == 'getlocal' and filter.node == self.loc then
- self.steps[#self.steps+1] = {
- type = 'truthy',
- pos = filter.finish,
- new = true,
- }
- self.steps[#self.steps+1] = {
- type = 'falsy',
- pos = filter.finish,
- ref1 = outStep,
- }
- end
- end
-end
-
----@param block parser.object
-function mt:_compileBlock(block)
- if self.blocks[block] then
- return
- end
- self.blocks[block] = true
- if block == self.mainBlock then
- return
- end
-
- local parentBlock = guide.getParentBlock(block)
- self:_compileBlock(parentBlock)
-
- if block.type == 'if' then
- ---@type vm.runner.step[]
- local finals = {}
- for i, childBlock in ipairs(block) do
- local blockStep = {
- type = 'save',
- tag = 'block',
- copy = true,
- pos = childBlock.start,
- }
- local outStep = {
- type = 'save',
- tag = 'out',
- copy = true,
- pos = childBlock.start,
- }
- self.steps[#self.steps+1] = blockStep
- self.steps[#self.steps+1] = outStep
- self.steps[#self.steps+1] = {
- type = 'push',
- ref1 = blockStep,
- pos = childBlock.start,
- }
- self:_compileNarrowByFilter(childBlock.filter, outStep, blockStep)
- if not childBlock.hasReturn
- and not childBlock.hasGoTo
- and not childBlock.hasBreak then
- local finalStep = {
- type = 'save',
- pos = childBlock.finish,
- tag = 'final #' .. i,
- }
- finals[#finals+1] = finalStep
- self.steps[#self.steps+1] = finalStep
- end
- self.steps[#self.steps+1] = {
- type = 'push',
- tag = 'reset child',
- ref1 = outStep,
- pos = childBlock.finish,
- }
- end
- self.steps[#self.steps+1] = {
- type = 'push',
- tag = 'reset if',
- pos = block.finish,
- copy = true,
- }
- for _, final in ipairs(finals) do
- self.steps[#self.steps+1] = {
- type = 'merge',
- ref2 = final,
- pos = block.finish,
- }
- end
- end
-
- if block.type == 'function'
- or block.type == 'while'
- or block.type == 'loop'
- or block.type == 'in'
- or block.type == 'repeat'
- or block.type == 'for' then
- local savePoint = {
- type = 'save',
- copy = true,
- pos = block.start,
- }
- self.steps[#self.steps+1] = {
- type = 'push',
- copy = true,
- pos = block.start,
- }
- self.steps[#self.steps+1] = savePoint
- self.steps[#self.steps+1] = {
- type = 'push',
- pos = block.finish,
- ref1 = savePoint,
- }
- end
-end
+mt._index = 1
---@return parser.object[]
function mt:_getCasts()
- local root = guide.getRoot(self.loc)
+ local root = guide.getRoot(self._loc)
if not root._casts then
root._casts = {}
local docs = root.docs
@@ -265,180 +30,332 @@ function mt:_getCasts()
return root._casts
end
-function mt:_preCompile()
- local startPos = self.loc.start
- local finishPos = 0
-
- for _, ref in ipairs(self.loc.ref) do
- self.steps[#self.steps+1] = {
- type = 'object',
- object = ref,
- pos = ref.range or ref.start,
- }
- if ref.start > finishPos then
- finishPos = ref.start
+---@param obj parser.object
+function mt:_markHas(obj)
+ while true do
+ if self._has[obj] then
+ return
end
- local block = guide.getParentBlock(ref)
- self:_compileBlock(block)
+ self._has[obj] = true
+ if obj == self._main then
+ return
+ end
+ obj = obj.parent
end
+end
- for i, step in ipairs(self.steps) do
- if step.type ~= 'object' then
- step.order = i
+function mt:_collect()
+ local startPos = self._loc.start
+ local finishPos = 0
+
+ for _, ref in ipairs(self._loc.ref) do
+ if ref.type == 'getlocal'
+ or ref.type == 'setlocal' then
+ self:_markHas(ref)
+ if ref.finish > finishPos then
+ finishPos = ref.finish
+ end
end
end
local casts = self:_getCasts()
for _, cast in ipairs(casts) do
- if cast.loc[1] == self.loc[1]
+ if cast.loc[1] == self._loc[1]
and cast.start > startPos
and cast.finish < finishPos
- and guide.getLocal(self.loc, self.loc[1], cast.start) == self.loc then
- self.steps[#self.steps+1] = {
- type = 'cast',
- cast = cast,
- pos = cast.start,
- }
+ and guide.getLocal(self._loc, self._loc[1], cast.start) == self._loc then
+ self._casts[#self._casts+1] = cast
end
end
-
- table.sort(self.steps, function (a, b)
- if a.pos == b.pos then
- return (a.order or 0) < (b.order or 0)
- else
- return a.pos < b.pos
- end
- end)
end
----@param loc parser.object
----@param node vm.node
+---@param pos integer
+---@param topNode vm.node
---@return vm.node
-local function checkAssert(loc, node)
- local parent = loc.parent
- if parent.type == 'binary' then
- if parent.op and (parent.op.type == '~=' or parent.op.type == '==') then
- local exp
- for i = 1, 2 do
- if parent[i] == loc then
- exp = parent[i % 2 + 1]
+function mt:_fastWardCasts(pos, topNode)
+ for i = self._index, #self._casts do
+ local action = self._casts[i]
+ if action.start > pos then
+ self._index = i
+ return topNode
+ end
+ topNode = topNode:copy()
+ for _, cast in ipairs(action.casts) do
+ if cast.mode == '+' then
+ if cast.optional then
+ topNode:addOptional()
end
- end
- if exp and guide.isLiteral(exp) then
- local callargs = parent.parent
- if callargs.type == 'callargs'
- and callargs.parent.node.special == 'assert'
- and callargs[1] == parent then
- if parent.op.type == '~=' then
- node:remove(exp.type)
- end
- if parent.op.type == '==' then
- node = vm.compileNode(exp)
- end
+ if cast.extends then
+ topNode:merge(vm.compileNode(cast.extends))
+ end
+ elseif cast.mode == '-' then
+ if cast.optional then
+ topNode:removeOptional()
+ end
+ if cast.extends then
+ topNode:removeNode(vm.compileNode(cast.extends))
+ end
+ else
+ if cast.extends then
+ topNode:clear()
+ topNode:merge(vm.compileNode(cast.extends))
end
end
end
end
- if parent.type == 'callargs'
- and parent.parent.node.special == 'assert'
- and parent[1] == loc then
- node:setTruthy()
- end
- return node
+ self._index = self._index + 1
+ return topNode
end
----@param callback fun(src: parser.object, node: vm.node)
-function mt:launch(callback)
- local topNode = vm.getNode(self.loc):copy()
- for _, step in ipairs(self.steps) do
- local node = step.ref1 and step.ref1.node or topNode
- if step.type == 'truthy' then
- if step.new then
- node = node:copy()
- topNode = node
- end
- node:setTruthy()
- elseif step.type == 'falsy' then
- if step.new then
- node = node:copy()
- topNode = node
- end
- node:setFalsy()
- elseif step.type == 'as' then
- if step.new then
- topNode = vm.createNode(vm.getGlobal('type', step.name))
- else
- node:clear()
- node:merge(vm.getGlobal('type', step.name))
- end
- elseif step.type == 'add' then
- if step.new then
- node = node:copy()
- topNode = node
- end
- node:merge(vm.getGlobal('type', step.name))
- elseif step.type == 'remove' then
- if step.new then
- node = node:copy()
- topNode = node
- end
- node:remove(step.name)
- elseif step.type == 'object' then
- topNode = callback(step.object, node) or node
- if step.object.type == 'getlocal' then
- topNode = checkAssert(step.object, node)
+---@param action parser.object
+---@param topNode vm.node
+---@param outNode? vm.node
+---@return vm.node topNode
+---@return vm.node outNode
+function mt:_lookIntoChild(action, topNode, outNode)
+ if not self._has[action]
+ or self._mark[action] then
+ return topNode, topNode or outNode
+ end
+ self._mark[action] = true
+ topNode = self:_fastWardCasts(action.start, topNode)
+ if action.type == 'getlocal' then
+ if action.node == self._loc then
+ self._callback(action, topNode)
+ if outNode then
+ topNode = topNode:copy():setTruthy()
+ outNode = outNode:copy():setFalsy()
end
- elseif step.type == 'save' then
- if step.copy then
- node = node:copy()
+ end
+ elseif action.type == 'function' then
+ self:_lookIntoBlock(action, topNode:copy())
+ elseif action.type == 'unary' then
+ if not action[1] then
+ goto RETURN
+ end
+ if action.op.type == 'not' then
+ outNode = outNode or topNode:copy()
+ outNode, topNode = self:_lookIntoChild(action[1], topNode, outNode)
+ outNode = outNode:copy()
+ end
+ elseif action.type == 'binary' then
+ if not action[1] or not action[2] then
+ goto RETURN
+ end
+ if action.op.type == 'and' then
+ topNode = self:_lookIntoChild(action[1], topNode, topNode:copy())
+ topNode = self:_lookIntoChild(action[2], topNode, topNode:copy())
+ elseif action.op.type == 'or' then
+ outNode = outNode or topNode:copy()
+ local topNode1, outNode1 = self:_lookIntoChild(action[1], topNode, outNode)
+ local topNode2, outNode2 = self:_lookIntoChild(action[2], outNode1, outNode1:copy())
+ topNode = vm.createNode(topNode1, topNode2)
+ outNode = outNode2:copy()
+ elseif action.op.type == '=='
+ or action.op.type == '~=' then
+ local handler, checker
+ for i = 1, 2 do
+ if guide.isLiteral(action[i]) then
+ checker = action[i]
+ handler = action[3-i] -- Copilot tells me use `3-i` instead of `i%2+1`
+ end
end
- step.node = node
- elseif step.type == 'push' then
- if step.copy then
- node = node:copy()
+ if not handler then
+ goto RETURN
end
- topNode = node
- elseif step.type == 'merge' then
- node:merge(step.ref2.node)
- elseif step.type == 'cast' then
- topNode = node:copy()
- for _, cast in ipairs(step.cast.casts) do
- if cast.mode == '+' then
- if cast.optional then
- topNode:addOptional()
+ if handler.type == 'getlocal'
+ and handler.node == self._loc then
+ -- if x == y then
+ topNode = self:_lookIntoChild(handler, topNode, outNode)
+ local checkerNode = vm.compileNode(checker)
+ if action.op.type == '==' then
+ topNode = checkerNode
+ if outNode then
+ outNode:removeNode(topNode)
end
- if cast.extends then
- topNode:merge(vm.compileNode(cast.extends))
- end
- elseif cast.mode == '-' then
- if cast.optional then
- topNode:removeOptional()
+ else
+ topNode:removeNode(checkerNode)
+ if outNode then
+ outNode = checkerNode
end
- if cast.extends then
- topNode:removeNode(vm.compileNode(cast.extends))
+ end
+ elseif handler.type == 'call'
+ and checker.type == 'string'
+ and handler.node.special == 'type'
+ and handler.args
+ and handler.args[1]
+ and handler.args[1].type == 'getlocal'
+ and handler.args[1].node == self._loc then
+ -- if type(x) == 'string' then
+ self:_lookIntoChild(handler, topNode:copy())
+ if action.op.type == '==' then
+ topNode:narrow(checker[1])
+ if outNode then
+ outNode:remove(checker[1])
end
else
- if cast.extends then
- topNode:clear()
- topNode:merge(vm.compileNode(cast.extends))
+ topNode:remove(checker[1])
+ if outNode then
+ outNode:narrow(checker[1])
+ end
+ end
+ elseif handler.type == 'getlocal'
+ and checker.type == 'string' then
+ local nodeValue = vm.getObjectValue(handler.node)
+ if nodeValue
+ and nodeValue.type == 'select'
+ and nodeValue.sindex == 1 then
+ local call = nodeValue.vararg
+ if call
+ and call.type == 'call'
+ and call.node.special == 'type'
+ and call.args
+ and call.args[1]
+ and call.args[1].type == 'getlocal'
+ and call.args[1].node == self._loc then
+ -- `local tp = type(x);if tp == 'string' then`
+ if action.op.type == '==' then
+ topNode:narrow(checker[1])
+ if outNode then
+ outNode:remove(checker[1])
+ end
+ else
+ topNode:remove(checker[1])
+ if outNode then
+ outNode:narrow(checker[1])
+ end
+ end
+ end
+ end
+ end
+ end
+ elseif action.type == 'loop'
+ or action.type == 'in'
+ or action.type == 'repeat'
+ or action.type == 'for' then
+ topNode = self:_lookIntoBlock(action, topNode:copy())
+ elseif action.type == 'while' then
+ local blockNode, mainNode
+ if action.filter then
+ blockNode, mainNode = self:_lookIntoChild(action.filter, topNode:copy(), topNode:copy())
+ else
+ blockNode = topNode:copy()
+ mainNode = topNode:copy()
+ end
+ blockNode = self:_lookIntoBlock(action, blockNode:copy())
+ topNode = mainNode:merge(blockNode)
+ if action.filter then
+ -- look into filter again
+ guide.eachSource(action.filter, function (src)
+ self._mark[src] = nil
+ end)
+ blockNode, topNode = self:_lookIntoChild(action.filter, topNode:copy(), topNode:copy())
+ end
+ elseif action.type == 'if' then
+ local hasElse
+ local mainNode = topNode:copy()
+ local blockNodes = {}
+ for _, subBlock in ipairs(action) do
+ local blockNode = mainNode:copy()
+ if subBlock.filter then
+ blockNode, mainNode = self:_lookIntoChild(subBlock.filter, blockNode, mainNode)
+ else
+ hasElse = true
+ mainNode:clear()
+ end
+ blockNode = self:_lookIntoBlock(subBlock, blockNode:copy())
+ local neverReturn = subBlock.hasReturn
+ or subBlock.hasGoTo
+ or subBlock.hasBreak
+ or subBlock.hasError
+ if not neverReturn then
+ blockNodes[#blockNodes+1] = blockNode
+ end
+ end
+ if not hasElse and not topNode:hasKnownType() then
+ mainNode:merge(vm.declareGlobal('type', 'unknown'))
+ end
+ for _, blockNode in ipairs(blockNodes) do
+ mainNode:merge(blockNode)
+ end
+ topNode = mainNode
+ elseif action.type == 'call' then
+ if action.node.special == 'assert' and action.args and action.args[1] then
+ topNode = self:_lookIntoChild(action.args[1], topNode, topNode:copy())
+ end
+ elseif action.type == 'paren' then
+ topNode, outNode = self:_lookIntoChild(action.exp, topNode, outNode)
+ elseif action.type == 'setlocal' then
+ if action.node == self._loc then
+ if action.value then
+ self:_lookIntoChild(action.value, topNode)
+ end
+ topNode = self._callback(action, topNode)
+ end
+ elseif action.type == 'local' then
+ if action.value
+ and action.ref
+ and action.value.type == 'select' then
+ local index = action.value.sindex
+ local call = action.value.vararg
+ if index == 1
+ and call.type == 'call'
+ and call.node
+ and call.node.special == 'type'
+ and call.args then
+ local getLoc = call.args[1]
+ if getLoc
+ and getLoc.type == 'getlocal'
+ and getLoc.node == self._loc then
+ for _, ref in ipairs(action.ref) do
+ self:_markHas(ref)
end
end
end
end
end
+ ::RETURN::
+ guide.eachChild(action, function (src)
+ if self._has[src] then
+ self:_lookIntoChild(src, topNode)
+ end
+ end)
+ return topNode, outNode or topNode
+end
+
+---@param block parser.object
+---@param topNode vm.node
+---@return vm.node topNode
+function mt:_lookIntoBlock(block, topNode)
+ if not self._has[block] then
+ return topNode
+ end
+ for _, action in ipairs(block) do
+ if self._has[action] then
+ topNode = self:_lookIntoChild(action, topNode)
+ end
+ end
+ topNode = self:_fastWardCasts(block.finish, topNode)
+ return topNode
end
---@param loc parser.object
----@return vm.runner
-function vm.createRunner(loc)
+---@param callback vm.runner.callback
+function vm.launchRunner(loc, callback)
+ local main = guide.getParentBlock(loc)
+ if not main then
+ return
+ end
local self = setmetatable({
- loc = loc,
- mainBlock = guide.getParentBlock(loc),
- blocks = {},
- steps = {},
+ _loc = loc,
+ _casts = {},
+ _mark = {},
+ _has = {},
+ _main = main,
+ _callback = callback,
}, mt)
- self:_preCompile()
+ self:_collect()
- return self
+ self:_lookIntoBlock(main, vm.getNode(loc):copy())
end
diff --git a/script/vm/sign.lua b/script/vm/sign.lua
index fe112bc2..7c95fd08 100644
--- a/script/vm/sign.lua
+++ b/script/vm/sign.lua
@@ -17,14 +17,14 @@ end
---@param uri uri
---@param args parser.object
---@param removeGeneric true?
----@return table<string, vm.node>
+---@return table<string, vm.node>?
function mt:resolve(uri, args, removeGeneric)
if not args then
return nil
end
local resolved = {}
- ---@param object parser.object
+ ---@param object vm.node.object
---@param node vm.node
local function resolve(object, node)
if object.type == 'doc.generic.name' then
@@ -33,6 +33,7 @@ function mt:resolve(uri, args, removeGeneric)
-- 'number' -> `T`
for n in node:eachObject() do
if n.type == 'string' then
+ ---@cast n parser.object
local type = vm.declareGlobal('type', n[1], guide.getUri(n))
resolved[key] = vm.createNode(type, resolved[key])
end
@@ -50,40 +51,59 @@ function mt:resolve(uri, args, removeGeneric)
end
if n.type == 'doc.type.table' then
-- { [integer]: number } -> T[]
- local tvalueNode = vm.getTableValue(uri, node, 'integer')
+ local tvalueNode = vm.getTableValue(uri, node, 'integer', true)
if tvalueNode then
resolve(object.node, tvalueNode)
end
end
if n.type == 'global' and n.cate == 'type' then
-- ---@field [integer]: number -> T[]
+ ---@cast n vm.global
vm.getClassFields(uri, n, vm.declareGlobal('type', 'integer'), false, function (field)
resolve(object.node, vm.compileNode(field.extends))
end)
end
+ if n.type == 'table' and #n >= 1 then
+ -- { x } / { ... } -> T[]
+ resolve(object.node, vm.compileNode(n[1]))
+ end
end
end
if object.type == 'doc.type.table' then
for _, ufield in ipairs(object.fields) do
local ufieldNode = vm.compileNode(ufield.name)
local uvalueNode = vm.compileNode(ufield.extends)
- if ufieldNode:get(1).type == 'doc.generic.name' and uvalueNode:get(1).type == 'doc.generic.name' then
+ local firstField = ufieldNode:get(1)
+ local firstValue = uvalueNode:get(1)
+ if not firstField or not firstValue then
+ goto CONTINUE
+ end
+ if firstField.type == 'doc.generic.name' and firstValue.type == 'doc.generic.name' then
-- { [number]: number} -> { [K]: V }
- local tfieldNode = vm.getTableKey(uri, node, 'any')
- local tvalueNode = vm.getTableValue(uri, node, 'any')
- resolve(ufieldNode:get(1), tfieldNode)
- resolve(uvalueNode:get(1), tvalueNode)
+ local tfieldNode = vm.getTableKey(uri, node, 'any', true)
+ local tvalueNode = vm.getTableValue(uri, node, 'any', true)
+ if tfieldNode then
+ resolve(firstField, tfieldNode)
+ end
+ if tvalueNode then
+ resolve(firstValue, tvalueNode)
+ end
else
if ufieldNode:get(1).type == 'doc.generic.name' then
-- { [number]: number}|number[] -> { [K]: number }
- local tnode = vm.getTableKey(uri, node, uvalueNode)
- resolve(ufieldNode:get(1), tnode)
+ local tnode = vm.getTableKey(uri, node, uvalueNode, true)
+ if tnode then
+ resolve(firstField, tnode)
+ end
elseif uvalueNode:get(1).type == 'doc.generic.name' then
-- { [number]: number}|number[] -> { [number]: V }
- local tnode = vm.getTableValue(uri, node, ufieldNode)
- resolve(uvalueNode:get(1), tnode)
+ local tnode = vm.getTableValue(uri, node, ufieldNode, true)
+ if tnode then
+ resolve(firstValue, tnode)
+ end
end
end
+ ::CONTINUE::
end
end
end
@@ -102,6 +122,7 @@ function mt:resolve(uri, args, removeGeneric)
if obj.type == 'doc.type.table'
or obj.type == 'doc.type.function'
or obj.type == 'doc.type.array' then
+ ---@cast obj parser.object
local hasGeneric
guide.eachSourceType(obj, 'doc.generic.name', function (src)
hasGeneric = true
@@ -111,7 +132,7 @@ function mt:resolve(uri, args, removeGeneric)
goto CONTINUE
end
end
- local view = vm.viewObject(obj)
+ local view = vm.viewObject(obj, uri)
if view then
knownTypes[view] = true
end
@@ -122,21 +143,31 @@ function mt:resolve(uri, args, removeGeneric)
-- remove un-generic type
---@param argNode vm.node
+ ---@param sign vm.node
---@param knownTypes table<string, true>
---@return vm.node
- local function buildArgNode(argNode, knownTypes)
+ local function buildArgNode(argNode, sign, knownTypes)
local newArgNode = vm.createNode()
+ local needRemoveNil = sign:hasFalsy()
for n in argNode:eachObject() do
- if argNode:hasFalsy() then
- goto CONTINUE
+ if needRemoveNil then
+ if n.type == 'nil' then
+ goto CONTINUE
+ end
+ if n.type == 'global' and n.cate == 'type' and n.name == 'nil' then
+ goto CONTINUE
+ end
end
- local view = vm.viewObject(n)
+ local view = vm.viewObject(n, uri)
if knownTypes[view] then
goto CONTINUE
end
newArgNode:merge(n)
::CONTINUE::
end
+ if not needRemoveNil and argNode:isOptional() then
+ newArgNode:addOptional()
+ end
return newArgNode
end
@@ -158,7 +189,7 @@ function mt:resolve(uri, args, removeGeneric)
local argNode = vm.compileNode(arg)
local knownTypes, genericNames = getSignInfo(sign)
if not isAllResolved(genericNames) then
- local newArgNode = buildArgNode(argNode, knownTypes)
+ local newArgNode = buildArgNode(argNode,sign, knownTypes)
for n in sign:eachObject() do
resolve(n, newArgNode)
end
diff --git a/script/vm/type.lua b/script/vm/type.lua
index c3264993..d112be2c 100644
--- a/script/vm/type.lua
+++ b/script/vm/type.lua
@@ -1,67 +1,252 @@
---@class vm
local vm = require 'vm.vm'
+local guide = require 'parser.guide'
+local config = require 'config.config'
+local util = require 'utility'
----@param uri uri
----@param child vm.node|string
----@param parent vm.node|string
----@param mark? table
----@return boolean
-function vm.isSubType(uri, child, parent, mark)
- if type(parent) == 'string' then
- parent = vm.createNode(vm.getGlobal('type', parent))
+---@param object vm.node.object
+---@return string?
+local function getNodeName(object)
+ if object.type == 'global' and object.cate == 'type' then
+ ---@cast object vm.global
+ return object.name
end
- if type(child) == 'string' then
- child = vm.createNode(vm.getGlobal('type', child))
+ if object.type == 'nil'
+ or object.type == 'boolean'
+ or object.type == 'number'
+ or object.type == 'string'
+ or object.type == 'table'
+ or object.type == 'function'
+ or object.type == 'integer' then
+ return object.type
+ end
+ if object.type == 'doc.type.boolean' then
+ return 'boolean'
end
+ if object.type == 'doc.type.integer' then
+ return 'integer'
+ end
+ if object.type == 'doc.type.function' then
+ return 'function'
+ end
+ if object.type == 'doc.type.table' then
+ return 'table'
+ end
+ if object.type == 'doc.type.array' then
+ return 'table'
+ end
+ if object.type == 'doc.type.string' then
+ return 'string'
+ end
+ return nil
+end
- if not child or not parent then
- return false
+---@param parentName string
+---@param child vm.node.object
+---@param uri uri
+---@return boolean?
+local function checkEnum(parentName, child, uri)
+ local parentClass = vm.getGlobal('type', parentName)
+ if not parentClass then
+ return nil
+ end
+ for _, set in ipairs(parentClass:getSets(uri)) do
+ if set.type == 'doc.enum' then
+ if not set._enums then
+ return false
+ end
+ if child.type ~= 'string'
+ and child.type ~= 'doc.type.string'
+ and child.type ~= 'integer'
+ and child.type ~= 'number'
+ and child.type ~= 'doc.type.integer' then
+ return false
+ end
+ return util.arrayHas(set._enums, child[1])
+ end
+ end
+
+ return nil
+end
+
+---@param parent vm.node.object
+---@param child vm.node.object
+---@return boolean
+local function checkValue(parent, child)
+ if parent.type == 'doc.type.integer' then
+ if child.type == 'integer'
+ or child.type == 'doc.type.integer'
+ or child.type == 'number' then
+ return parent[1] == child[1]
+ end
+ elseif parent.type == 'doc.type.string' then
+ if child.type == 'string'
+ or child.type == 'doc.type.string' then
+ return parent[1] == child[1]
+ end
end
+ return true
+end
+
+---@param uri uri
+---@param child vm.node|string|vm.node.object
+---@param parent vm.node|string|vm.node.object
+---@param mark? table
+---@return boolean
+function vm.isSubType(uri, child, parent, mark)
mark = mark or {}
- for obj in child:eachObject() do
- if obj.type ~= 'global'
- or obj.cate ~= 'type' then
- goto CONTINUE_CHILD
+
+ if type(child) == 'string' then
+ local global = vm.getGlobal('type', child)
+ if not global then
+ return false
+ end
+ child = global
+ elseif child.type == 'vm.node' then
+ if config.get(uri, 'Lua.type.weakUnionCheck') then
+ local hasKnownType
+ for n in child:eachObject() do
+ if getNodeName(n) then
+ hasKnownType = true
+ if vm.isSubType(uri, n, parent, mark) then
+ return true
+ end
+ end
+ end
+ return not hasKnownType
+ else
+ local weakNil = config.get(uri, 'Lua.type.weakNilCheck')
+ for n in child:eachObject() do
+ local nodeName = getNodeName(n)
+ if nodeName
+ and not (nodeName == 'nil' and weakNil)
+ and not vm.isSubType(uri, n, parent, mark) then
+ return false
+ end
+ end
+ if not weakNil and child:isOptional() then
+ if not vm.isSubType(uri, 'nil', parent, mark) then
+ return false
+ end
+ end
+ return true
end
- if mark[obj.name] then
+ end
+
+ if type(parent) == 'string' then
+ local global = vm.getGlobal('type', parent)
+ if not global then
return false
end
- mark[obj.name] = true
- for parentNode in parent:eachObject() do
- if parentNode.type ~= 'global'
- or parentNode.cate ~= 'type' then
- goto CONTINUE_PARENT
+ parent = global
+ elseif parent.type == 'vm.node' then
+ for n in parent:eachObject() do
+ if getNodeName(n)
+ and vm.isSubType(uri, child, n, mark) then
+ return true
end
- if parentNode.name == 'any' or obj.name == 'any' then
+ if n.type == 'doc.generic.name' then
return true
end
-
- if parentNode.name == obj.name then
+ end
+ if parent:isOptional() then
+ if vm.isSubType(uri, child, 'nil', mark) then
return true
end
+ end
+ return false
+ end
+
+ ---@cast child vm.node.object
+ ---@cast parent vm.node.object
- for _, set in ipairs(obj:getSets(uri)) do
+ local childName = getNodeName(child)
+ local parentName = getNodeName(parent)
+ if childName == 'any'
+ or parentName == 'any'
+ or childName == 'unknown'
+ or parentName == 'unknown'
+ or not childName
+ or not parentName then
+ return true
+ end
+
+ if childName == parentName then
+ if not checkValue(parent, child) then
+ return false
+ end
+ return true
+ end
+
+ if parentName == 'number' and childName == 'integer' then
+ return true
+ end
+
+ if parentName == 'integer' and childName == 'number' then
+ if config.get(uri, 'Lua.type.castNumberToInteger') then
+ return true
+ end
+ if child.type == 'number'
+ and child[1]
+ and not math.tointeger(child[1]) then
+ return false
+ end
+ if child.type == 'global'
+ and child.cate == 'type' then
+ return false
+ end
+ return true
+ end
+
+ local isEnum = checkEnum(parentName, child, uri)
+ if isEnum ~= nil then
+ return isEnum
+ end
+
+ -- TODO: check duck
+ if parentName == 'table' and not guide.isBasicType(childName) then
+ return true
+ end
+ if childName == 'table' and not guide.isBasicType(parentName) then
+ return true
+ end
+
+ -- check class parent
+ if childName and not mark[childName] then
+ mark[childName] = true
+ local isBasicType = guide.isBasicType(childName)
+ local childClass = vm.getGlobal('type', childName)
+ if childClass then
+ for _, set in ipairs(childClass:getSets(uri)) do
if set.type == 'doc.class' and set.extends then
for _, ext in ipairs(set.extends) do
if ext.type == 'doc.extends.name'
- and vm.isSubType(uri, ext[1], parentNode.name, mark) then
+ and (not isBasicType or guide.isBasicType(ext[1]))
+ and vm.isSubType(uri, ext[1], parent, mark) then
return true
end
end
end
- if set.type == 'doc.alias' and set.extends then
- for _, ext in ipairs(set.extends.types) do
- if ext.type == 'doc.type.name'
- and vm.isSubType(uri, ext[1], parentNode.name, mark) then
- return true
- end
- end
+ if set.type == 'doc.alias'
+ or set.type == 'doc.enum' then
+ return true
end
end
- ::CONTINUE_PARENT::
end
- ::CONTINUE_CHILD::
+ mark[childName] = nil
+ end
+
+ --[[
+ ---@class A: string
+
+ ---@type A
+ local x = '' --> `string` set to `A`
+ ]]
+ if guide.isBasicType(childName)
+ and guide.isLiteral(child)
+ and vm.isSubType(uri, parentName, childName) then
+ return true
end
return false
@@ -69,16 +254,24 @@ end
---@param uri uri
---@param tnode vm.node
----@param knode vm.node
+---@param knode vm.node|string
+---@param inversion? boolean
---@return vm.node?
-function vm.getTableValue(uri, tnode, knode)
+function vm.getTableValue(uri, tnode, knode, inversion)
local result = vm.createNode()
for tn in tnode:eachObject() do
if tn.type == 'doc.type.table' then
for _, field in ipairs(tn.fields) do
- if vm.isSubType(uri, vm.compileNode(field.name), knode) then
- if field.extends then
- result:merge(vm.compileNode(field.extends))
+ if field.name.type ~= 'doc.field.name'
+ and field.extends then
+ if inversion then
+ if vm.isSubType(uri, vm.compileNode(field.name), knode) then
+ result:merge(vm.compileNode(field.extends))
+ end
+ else
+ if vm.isSubType(uri, knode, vm.compileNode(field.name)) then
+ result:merge(vm.compileNode(field.extends))
+ end
end
end
end
@@ -88,25 +281,38 @@ function vm.getTableValue(uri, tnode, knode)
end
if tn.type == 'table' then
for _, field in ipairs(tn) do
- if field.type == 'tableindex' then
- if field.value then
- result:merge(vm.compileNode(field.value))
- end
+ if field.type == 'tableindex'
+ and field.value then
+ result:merge(vm.compileNode(field.value))
end
- if field.type == 'tablefield' then
- if vm.isSubType(uri, knode, 'string') then
- if field.value then
+ if field.type == 'tablefield'
+ and field.value then
+ if inversion then
+ if vm.isSubType(uri, 'string', knode) then
+ result:merge(vm.compileNode(field.value))
+ end
+ else
+ if vm.isSubType(uri, knode, 'string') then
result:merge(vm.compileNode(field.value))
end
end
end
- if field.type == 'tableexp' then
- if vm.isSubType(uri, knode, 'integer') and field.tindex == 1 then
- if field.value then
+ if field.type == 'tableexp'
+ and field.value
+ and field.tindex == 1 then
+ if inversion then
+ if vm.isSubType(uri, 'integer', knode) then
+ result:merge(vm.compileNode(field.value))
+ end
+ else
+ if vm.isSubType(uri, knode, 'integer') then
result:merge(vm.compileNode(field.value))
end
end
end
+ if field.type == 'varargs' then
+ result:merge(vm.compileNode(field))
+ end
end
end
end
@@ -118,16 +324,24 @@ end
---@param uri uri
---@param tnode vm.node
----@param vnode vm.node
+---@param vnode vm.node|string|vm.object
+---@param reverse? boolean
---@return vm.node?
-function vm.getTableKey(uri, tnode, vnode)
+function vm.getTableKey(uri, tnode, vnode, reverse)
local result = vm.createNode()
for tn in tnode:eachObject() do
if tn.type == 'doc.type.table' then
for _, field in ipairs(tn.fields) do
- if field.extends then
- if vm.isSubType(uri, vm.compileNode(field.extends), vnode) then
- result:merge(vm.compileNode(field.name))
+ if field.name.type ~= 'doc.field.name'
+ and field.extends then
+ if reverse then
+ if vm.isSubType(uri, vm.compileNode(field.extends), vnode) then
+ result:merge(vm.compileNode(field.name))
+ end
+ else
+ if vm.isSubType(uri, vnode, vm.compileNode(field.extends)) then
+ result:merge(vm.compileNode(field.name))
+ end
end
end
end
@@ -156,3 +370,49 @@ function vm.getTableKey(uri, tnode, vnode)
end
return result
end
+
+---@param uri uri
+---@param defNode vm.node
+---@param refNode vm.node
+---@return boolean
+function vm.canCastType(uri, defNode, refNode)
+ local defInfer = vm.getInfer(defNode)
+ local refInfer = vm.getInfer(refNode)
+
+ if defInfer:hasAny(uri) then
+ return true
+ end
+ if refInfer:hasAny(uri) then
+ return true
+ end
+ if defInfer:view(uri) == 'unknown' then
+ return true
+ end
+ if refInfer:view(uri) == 'unknown' then
+ return true
+ end
+
+ if vm.isSubType(uri, refNode, 'nil') then
+ -- allow `local x = {};x = nil`,
+ -- but not allow `local x ---@type table;x = nil`
+ if defInfer:hasType(uri, 'table')
+ and not defNode:hasType 'table' then
+ return true
+ end
+ end
+
+ if vm.isSubType(uri, refNode, 'number') then
+ -- allow `local x = 0;x = 1.0`,
+ -- but not allow `local x ---@type integer;x = 1.0`
+ if defInfer:hasType(uri, 'integer')
+ and not defNode:hasType 'integer' then
+ return true
+ end
+ end
+
+ if vm.isSubType(uri, refNode, defNode) then
+ return true
+ end
+
+ return false
+end
diff --git a/script/vm/value.lua b/script/vm/value.lua
index d29ca9d0..7eab4a8e 100644
--- a/script/vm/value.lua
+++ b/script/vm/value.lua
@@ -4,11 +4,14 @@ local vm = require 'vm.vm'
---@param source parser.object?
---@return boolean|nil
-function vm.test(source)
+function vm.testCondition(source)
if not source then
return nil
end
local node = vm.compileNode(source)
+ if node.optional then
+ return nil
+ end
local hasTrue, hasFalse
for n in node:eachObject() do
if n.type == 'boolean'
@@ -19,24 +22,20 @@ function vm.test(source)
if n[1] == false then
hasFalse = true
end
- end
- if n.type == 'global' and n.cate == 'type' then
- if n.name == 'true' then
- hasTrue = true
+ elseif n.type == 'global' and n.cate == 'type' then
+ if n.name == 'boolean'
+ or n.name == 'unknown' then
+ return nil
end
if n.name == 'false'
or n.name == 'nil' then
hasFalse = true
+ else
+ hasTrue = true
end
- end
- if n.type == 'nil' then
+ elseif n.type == 'nil' then
hasFalse = true
- end
- if n.type == 'string'
- or n.type == 'number'
- or n.type == 'integer'
- or n.type == 'table'
- or n.type == 'function' then
+ elseif guide.isLiteral(n) then
hasTrue = true
end
end
@@ -50,8 +49,8 @@ function vm.test(source)
end
end
----@param v vm.object
----@return string?
+---@param v vm.node.object
+---@return string|false
local function getUnique(v)
if v.type == 'boolean' then
if v[1] == nil then
@@ -72,16 +71,18 @@ local function getUnique(v)
return ('num:%s'):format(v[1])
end
if v.type == 'table' then
+ ---@cast v parser.object
return ('table:%s@%d'):format(guide.getUri(v), v.start)
end
if v.type == 'function' then
+ ---@cast v parser.object
return ('func:%s@%d'):format(guide.getUri(v), v.start)
end
return false
end
----@param a vm.object?
----@param b vm.object?
+---@param a parser.object?
+---@param b parser.object?
---@return boolean|nil
function vm.equal(a, b)
if not a or not b then
@@ -141,7 +142,7 @@ function vm.getInteger(v)
end
---@param v vm.object?
----@return integer?
+---@return string?
function vm.getString(v)
if not v then
return nil
diff --git a/script/vm/vm.lua b/script/vm/vm.lua
index 8117d311..5437b632 100644
--- a/script/vm/vm.lua
+++ b/script/vm/vm.lua
@@ -64,6 +64,20 @@ function m.getObjectValue(source)
return nil
end
+---@param source parser.object
+---@return parser.object?
+function m.getObjectFunctionValue(source)
+ local value = m.getObjectValue(source)
+ if value == nil then return end
+ if value.type == 'function' or value.type == 'doc.type.function' then
+ return value
+ end
+ if value.type == 'getlocal' then
+ return m.getObjectFunctionValue(value.node)
+ end
+ return value
+end
+
m.cacheTracker = setmetatable({}, weakMT)
function m.flushCache()
diff --git a/script/workspace/require-path.lua b/script/workspace/require-path.lua
index aec298a6..1b7b25f9 100644
--- a/script/workspace/require-path.lua
+++ b/script/workspace/require-path.lua
@@ -3,30 +3,43 @@ local files = require 'files'
local furi = require 'file-uri'
local workspace = require "workspace"
local config = require 'config'
-local collector = require 'core.collector'
local scope = require 'workspace.scope'
+local util = require 'utility'
---@class require-path
local m = {}
-local function addRequireName(suri, uri, name)
- local separator = config.get(uri, 'Lua.completion.requireSeparator')
- local fsname = name:gsub('%' .. separator, '/')
- local scp = scope.getScope(suri)
- ---@type collector
- local clt = scp:get('requireName') or scp:set('requireName', collector())
- clt:subscribe(uri, fsname, name)
+---@class require-manager
+---@field scp scope
+---@field nameMap table<string, string>
+---@field visibleCache table<string, require-manager.visibleResult[]>
+local mt = {}
+mt.__index = mt
+
+---@alias require-manager.visibleResult { searcher: string, name: string }
+
+---@param scp scope
+---@return require-manager
+local function createRequireManager(scp)
+ return setmetatable({
+ scp = scp,
+ nameMap = {},
+ visibleCache = {},
+ }, mt)
end
--- `aaa/bbb/ccc.lua` 与 `?.lua` 将返回 `aaa.bbb.cccc`
-local function getOnePath(uri, path, searcher)
- local separator = config.get(uri, 'Lua.completion.requireSeparator')
+---@param path string
+---@param searcher string
+---@return string?
+function mt:getRequireNameByPath(path, searcher)
+ local separator = config.get(self.scp.uri, 'Lua.completion.requireSeparator')
local stemPath = path
: gsub('%.[^%.]+$', '')
: gsub('[/\\%.]+', separator)
local stemSearcher = searcher
: gsub('%.[^%.]+$', '')
- : gsub('[/\\%.]+', separator)
+ : gsub('[/\\]+', separator)
local start = stemSearcher:match '()%?' or 1
if stemPath:sub(1, start - 1) ~= stemSearcher:sub(1, start - 1) then
return nil
@@ -41,162 +54,188 @@ local function getOnePath(uri, path, searcher)
return nil
end
-function m.getVisiblePath(suri, path)
- local searchers = config.get(suri, 'Lua.runtime.path')
- local strict = config.get(suri, 'Lua.runtime.pathStrict')
- path = workspace.normalize(path)
+---@param path string
+---@return require-manager.visibleResult[]
+function mt:getRequireResultByPath(path)
local uri = furi.encode(path)
- local scp = scope.getScope(suri)
- if not scp:isChildUri(uri)
- and not scp:isLinkedUri(uri) then
- return {}
- end
- local libraryPath = furi.decode(files.getLibraryUri(suri, uri))
- local cache = scp:get('visiblePath') or scp:set('visiblePath', {})
- local result = cache[path]
- if not result then
- result = {}
- cache[path] = result
- for _, searcher in ipairs(searchers) do
- local isAbsolute = searcher:match '^[/\\]'
- or searcher:match '^%a+%:'
- searcher = workspace.normalize(searcher)
- local cutedPath = path
- local currentPath = path
- local head
- local pos = 1
- if not isAbsolute then
- if libraryPath then
- currentPath = currentPath:sub(#libraryPath + 2)
- else
- currentPath = workspace.getRelativePath(uri)
- end
+ local searchers = config.get(self.scp.uri, 'Lua.runtime.path')
+ local strict = config.get(self.scp.uri, 'Lua.runtime.pathStrict')
+ local libUri = files.getLibraryUri(self.scp.uri, uri)
+ local libraryPath = libUri and furi.decode(libUri)
+ local result = {}
+ for _, searcher in ipairs(searchers) do
+ local isAbsolute = searcher:match '^[/\\]'
+ or searcher:match '^%a+%:'
+ searcher = workspace.normalize(searcher)
+ if searcher:sub(1, 1) == '.' then
+ strict = true
+ end
+ local cutedPath = path
+ local currentPath = path
+ local head
+ local pos = 1
+ if not isAbsolute then
+ if libraryPath then
+ currentPath = currentPath:sub(#libraryPath + 2)
+ else
+ currentPath = workspace.getRelativePath(uri)
end
- repeat
- cutedPath = currentPath:sub(pos)
- head = currentPath:sub(1, pos - 1)
- pos = currentPath:match('[/\\]+()', pos)
- if platform.OS == 'Windows' then
- searcher = searcher :gsub('[/\\]+', '\\')
- else
- searcher = searcher :gsub('[/\\]+', '/')
- end
- local expect = getOnePath(suri, cutedPath, searcher)
- if expect then
- local mySearcher = searcher
- if head then
- mySearcher = head .. searcher
- end
- result[#result+1] = {
- searcher = mySearcher,
- expect = expect,
- }
- addRequireName(suri, uri, expect)
- end
- until not pos or strict
end
- end
- return result
-end
---- 查找符合指定require path的所有uri
----@param path string
-function m.findUrisByRequirePath(suri, path)
- if type(path) ~= 'string' then
- return {}
- end
- local separator = config.get(suri, 'Lua.completion.requireSeparator')
- local fspath = path:gsub('%' .. separator, '/')
- tracy.ZoneBeginN('findUrisByRequirePath')
- local results = {}
- local searchers = {}
- for uri in files.eachDll() do
- local opens = files.getDllOpens(uri) or {}
- for _, open in ipairs(opens) do
- if open == fspath then
- results[#results+1] = uri
+ -- handle `../?.lua`
+ local parentCount = 0
+ for _ = 1, 1000 do
+ if searcher:match '^%.%.[/\\]' then
+ parentCount = parentCount + 1
+ searcher = searcher:sub(4)
+ else
+ break
end
end
- end
-
- ---@type collector
- local clt = scope.getScope(suri):get('requireName')
- if clt then
- for _, uri in clt:each(suri, fspath) do
- if uri ~= suri then
- local infos = m.getVisiblePath(suri, furi.decode(uri))
- for _, info in ipairs(infos) do
- local fsexpect = info.expect:gsub('%' .. separator, '/')
- if fsexpect == fspath then
- results[#results+1] = uri
- searchers[uri] = info.searcher
- end
+ if parentCount > 0 then
+ local parentPath = libraryPath
+ or (self.scp.uri and furi.decode(self.scp.uri))
+ if parentPath then
+ local tail
+ for _ = 1, parentCount do
+ parentPath, tail = parentPath:match '^(.+)[/\\]([^/\\]*)$'
+ currentPath = tail .. '/' .. currentPath
end
end
end
+
+ repeat
+ cutedPath = currentPath:sub(pos)
+ head = currentPath:sub(1, pos - 1)
+ pos = currentPath:match('[/\\]+()', pos)
+ if platform.OS == 'Windows' then
+ searcher = searcher :gsub('[/\\]+', '\\')
+ else
+ searcher = searcher :gsub('[/\\]+', '/')
+ end
+ local name = self:getRequireNameByPath(cutedPath, searcher)
+ if name then
+ local mySearcher = searcher
+ if head then
+ mySearcher = head .. searcher
+ end
+ result[#result+1] = {
+ name = name,
+ searcher = mySearcher,
+ }
+ end
+ until not pos or strict
end
+ return result
+end
- tracy.ZoneEnd()
- return results, searchers
+---@param name string
+function mt:addName(name)
+ local separator = config.get(self.scp.uri, 'Lua.completion.requireSeparator')
+ local fsname = name:gsub('%' .. separator, '/')
+ self.nameMap[fsname] = name
end
-local function createVisiblePath(uri)
- for _, scp in ipairs(workspace.folders) do
- m.getVisiblePath(scp.uri, furi.decode(uri))
+---@return require-manager.visibleResult[]
+function mt:getVisiblePath(path)
+ local uri = furi.encode(path)
+ if not self.scp:isChildUri(uri)
+ and not self.scp:isLinkedUri(uri) then
+ return {}
+ end
+ path = workspace.normalize(path)
+ local result = self.visibleCache[path]
+ if not result then
+ result = self:getRequireResultByPath(path)
+ self.visibleCache[path] = result
end
- m.getVisiblePath(nil, furi.decode(uri))
+ return result
end
-local function removeVisiblePath(uri)
- local path = furi.decode(uri)
- path = workspace.normalize(path)
- if not path then
- return
+--- 查找符合指定require name的所有uri
+---@param suri uri
+---@param name string
+---@return uri[]
+---@return table<uri, string>?
+function mt:findUrisByRequireName(suri, name)
+ if type(name) ~= 'string' then
+ return {}
end
- for _, scp in ipairs(workspace.folders) do
- if scp:get('visiblePath') then
- scp:get('visiblePath')[path] = nil
+ local searchers = config.get(self.scp.uri, 'Lua.runtime.path')
+ local strict = config.get(self.scp.uri, 'Lua.runtime.pathStrict')
+ local separator = config.get(self.scp.uri, 'Lua.completion.requireSeparator')
+ local path = name:gsub('%' .. separator, '/')
+ local results = {}
+ local searcherMap = {}
+
+ for _, searcher in ipairs(searchers) do
+ local fspath = searcher:gsub('%?', (path:gsub('%%', '%%%%')))
+ local fullPath = workspace.getAbsolutePath(self.scp.uri, fspath)
+ if fullPath then
+ local fullUri = furi.encode(fullPath)
+ if files.exists(fullUri)
+ and fullUri ~= suri then
+ results[#results+1] = fullUri
+ searcherMap[fullUri] = searcher
+ end
end
- ---@type collector
- local clt = scp:get('requireName')
- if clt then
- clt:dropUri(uri)
+ if not strict then
+ local tail = '/' .. furi.encode(fspath):gsub('^file:[/]*', '')
+ for uri in files.eachFile(self.scp.uri) do
+ if not searcherMap[uri]
+ and suri ~= uri
+ and util.stringEndWith(uri, tail) then
+ results[#results+1] = uri
+ local parentUri = files.getLibraryUri(self.scp.uri, uri) or self.scp.uri
+ if parentUri == nil or parentUri == '' then
+ parentUri = furi.encode ''
+ end
+ local relative = uri:sub(#parentUri + 1):sub(1, - #tail)
+ searcherMap[uri] = workspace.normalize(relative .. searcher)
+ end
+ end
end
end
- if scope.fallback:get('visiblePath') then
- scope.fallback:get('visiblePath')[path] = nil
- end
- ---@type collector
- local clt = scope.fallback:get('requireName')
- if clt then
- clt:dropUri(uri)
+
+ for uri in files.eachDll() do
+ local opens = files.getDllOpens(uri) or {}
+ for _, open in ipairs(opens) do
+ if open == path then
+ results[#results+1] = uri
+ end
+ end
end
+
+ return results, searcherMap
end
-function m.flush(suri)
- local scp = scope.getScope(suri)
- scp:set('visiblePath', {})
- ---@type collector
- local clt = scp:get('requireName')
- if clt then
- clt:dropAll()
- end
- for uri in files.eachFile(suri) do
- m.getVisiblePath(scp.uri, furi.decode(uri))
- end
+---@param uri uri
+---@param path string
+---@return require-manager.visibleResult[]
+function m.getVisiblePath(uri, path)
+ local scp = scope.getScope(uri)
+ ---@type require-manager
+ local mgr = scp:get 'requireManager'
+ or scp:set('requireManager', createRequireManager(scp))
+ return mgr:getVisiblePath(path)
end
-for _, scp in ipairs(scope.folders) do
- m.flush(scp.uri)
+---@param uri uri
+---@param name string
+function m.findUrisByRequireName(uri, name)
+ local scp = scope.getScope(uri)
+ ---@type require-manager
+ local mgr = scp:get 'requireManager'
+ or scp:set('requireManager', createRequireManager(scp))
+ return mgr:findUrisByRequireName(uri, name)
end
-m.flush(nil)
files.watch(function (ev, uri)
- if ev == 'create' then
- createVisiblePath(uri)
- end
- if ev == 'remove' then
- removeVisiblePath(uri)
+ if ev == 'create' or ev == 'delete' then
+ for _, scp in ipairs(workspace.folders) do
+ scp:set('requireManager', nil)
+ end
+ scope.fallback:set('requireManager', nil)
end
end)
@@ -204,7 +243,8 @@ config.watch(function (uri, key, value, oldValue)
if key == 'Lua.completion.requireSeparator'
or key == 'Lua.runtime.path'
or key == 'Lua.runtime.pathStrict' then
- m.flush(uri)
+ local scp = scope.getScope(uri)
+ scp:set('requireManager', nil)
end
end)
diff --git a/script/workspace/scope.lua b/script/workspace/scope.lua
index a0f4fbf7..4649d354 100644
--- a/script/workspace/scope.lua
+++ b/script/workspace/scope.lua
@@ -11,6 +11,7 @@ local m = {}
---@field _links table<uri, boolean>
---@field _data table<string, any>
---@field _gc gc
+---@field _removed? true
local mt = {}
mt.__index = mt
@@ -85,7 +86,7 @@ function mt:getLinkedUri(uri)
end
---@param uri uri
----@return uri
+---@return uri?
function mt:getRootUri(uri)
if self:isChildUri(uri) then
return self.uri
@@ -117,9 +118,30 @@ end
function mt:flushGC()
self._gc:remove()
+ if self._removed then
+ return
+ end
self._gc = gc()
end
+function mt:remove()
+ if self._removed then
+ return
+ end
+ self._removed = true
+ for i, scp in ipairs(m.folders) do
+ if scp == self then
+ table.remove(m.folders, i)
+ break
+ end
+ end
+ self:flushGC()
+end
+
+function mt:isRemoved()
+ return self._removed == true
+end
+
---@param scopeType scope.type
---@return scope
local function createScope(scopeType)
@@ -164,7 +186,7 @@ function m.createFolder(uri)
end
---@param uri uri
----@return scope
+---@return scope?
function m.getFolder(uri)
for _, scope in ipairs(m.folders) do
if scope:isChildUri(uri) then
@@ -175,7 +197,7 @@ function m.getFolder(uri)
end
---@param uri uri
----@return scope
+---@return scope?
function m.getLinkedScope(uri)
if m.override and m.override:isLinkedUri(uri) then
return m.override
@@ -188,6 +210,7 @@ function m.getLinkedScope(uri)
if m.fallback:isLinkedUri(uri) then
return m.fallback
end
+ return nil
end
---@param uri uri
diff --git a/script/workspace/workspace.lua b/script/workspace/workspace.lua
index 33f8784d..9d2ad637 100644
--- a/script/workspace/workspace.lua
+++ b/script/workspace/workspace.lua
@@ -7,12 +7,12 @@ local glob = require 'glob'
local platform = require 'bee.platform'
local await = require 'await'
local client = require 'client'
-local plugin = require 'plugin'
local util = require 'utility'
local fw = require 'filewatch'
local scope = require 'workspace.scope'
local loading = require 'workspace.loading'
local inspect = require 'inspect'
+local lang = require 'language'
---@class workspace
local m = {}
@@ -45,13 +45,35 @@ end
--- 初始化工作区
function m.create(uri)
+ if furi.isValid(uri) then
+ uri = furi.normalize(uri)
+ end
log.info('Workspace create: ', uri)
- local path = m.normalize(furi.decode(uri))
- fw.watch(path)
+ if uri == furi.encode '/'
+ or uri == furi.encode(os.getenv 'HOME' or '') then
+ client.showMessage('Error', lang.script('WORKSPACE_NOT_ALLOWED', furi.decode(uri)))
+ return
+ end
local scp = scope.createFolder(uri)
m.folders[#m.folders+1] = scp
end
+function m.remove(uri)
+ log.info('Workspace remove: ', uri)
+ for i, scp in ipairs(m.folders) do
+ if scp.uri == uri then
+ scp:remove()
+ table.remove(m.folders, i)
+ scp:set('ready', false)
+ scp:set('nativeMatcher', nil)
+ scp:set('libraryMatcher', nil)
+ scp:removeAllLinks()
+ m.flushFiles(scp)
+ return
+ end
+ end
+end
+
function m.reset()
---@type scope[]
m.folders = {}
@@ -135,7 +157,7 @@ function m.getNativeMatcher(scp)
end
end
end
- for path in pairs(config.get(scp.uri, 'Lua.workspace.library')) do
+ for _, path in ipairs(config.get(scp.uri, 'Lua.workspace.library')) do
path = m.getAbsolutePath(scp.uri, path)
if path then
log.debug('Ignore by library:', path)
@@ -148,7 +170,7 @@ function m.getNativeMatcher(scp)
end
local matcher = glob.gitignore(pattern, {
- root = furi.decode(scp.uri),
+ root = scp.uri and furi.decode(scp.uri),
ignoreCase = platform.OS == 'Windows',
}, globInteferFace)
@@ -177,7 +199,7 @@ function m.getLibraryMatchers(scp)
end
local librarys = {}
- for path in pairs(config.get(scp.uri, 'Lua.workspace.library')) do
+ for _, path in ipairs(config.get(scp.uri, 'Lua.workspace.library')) do
path = m.getAbsolutePath(scp.uri, path)
if path then
librarys[m.normalize(path)] = true
@@ -273,6 +295,10 @@ function m.awaitPreload(scp)
scp:flushGC()
+ if scp:isRemoved() then
+ return
+ end
+
local ld <close> = loading.create(scp)
scp:set('loading', ld)
@@ -283,22 +309,35 @@ function m.awaitPreload(scp)
if scp.uri then
log.info('Scan files at:', scp:getName())
+ local count = 0
---@async
native:scan(furi.decode(scp.uri), function (path)
local uri = files.getRealUri(furi.encode(path))
scp:get('cachedUris')[uri] = true
ld:loadFile(uri)
+ end, function () ---@async
+ count = count + 1
+ if count == 100000 then
+ client.showMessage('Warning', lang.script('WORKSPACE_SCAN_TOO_MUCH', count, furi.decode(scp.uri)))
+ end
end)
+ scp:gc(fw.watch(m.normalize(furi.decode(scp.uri))))
end
for _, libMatcher in ipairs(librarys) do
log.info('Scan library at:', libMatcher.uri)
+ local count = 0
scp:addLink(libMatcher.uri)
---@async
libMatcher.matcher:scan(furi.decode(libMatcher.uri), function (path)
local uri = files.getRealUri(furi.encode(path))
scp:get('cachedUris')[uri] = true
ld:loadFile(uri, libMatcher.uri)
+ end, function () ---@async
+ count = count + 1
+ if count == 100000 then
+ client.showMessage('Warning', lang.script('WORKSPACE_SCAN_TOO_MUCH', count, furi.decode(libMatcher.uri)))
+ end
end)
scp:gc(fw.watch(furi.decode(libMatcher.uri)))
end
@@ -336,9 +375,6 @@ end
---@param path string
---@return string
function m.normalize(path)
- if not path then
- return nil
- end
path = path:gsub('%$%{(.-)%}', function (key)
if key == '3rd' then
return (ROOT / 'meta' / '3rd'):string()
@@ -350,9 +386,20 @@ function m.normalize(path)
end)
path = util.expandPath(path)
path = path:gsub('^%.[/\\]+', '')
+ for _ = 1, 1000 do
+ if path:sub(1, 2) == '..' then
+ break
+ end
+ local count
+ path, count = path:gsub('[^/\\]+[/\\]+%.%.[/\\]', '/', 1)
+ if count == 0 then
+ break
+ end
+ end
if platform.OS == 'Windows' then
path = path:gsub('[/\\]+', '\\')
:gsub('[/\\]+$', '')
+ :gsub('^(%a:)$', '%1\\')
else
path = path:gsub('[/\\]+', '/')
:gsub('[/\\]+$', '')
@@ -360,11 +407,10 @@ function m.normalize(path)
return path
end
----@return string
+---@param folderUri? uri
+---@param path string
+---@return string?
function m.getAbsolutePath(folderUri, path)
- if not path or path == '' then
- return nil
- end
path = m.normalize(path)
if fs.path(path):is_relative() then
if not folderUri then
@@ -378,6 +424,7 @@ end
---@param uriOrPath uri|string
---@return string
+---@return boolean suc
function m.getRelativePath(uriOrPath)
local path, uri
if uriOrPath:sub(1, 5) == 'file:' then
@@ -427,16 +474,24 @@ function m.flushFiles(scp)
for uri in pairs(cachedUris) do
files.delRef(uri)
end
+ collectgarbage()
+ collectgarbage()
+ -- TODO: wait maillist
+ collectgarbage 'restart'
end
---@param scp scope
function m.resetFiles(scp)
local cachedUris = scp:get 'cachedUris'
- if not cachedUris then
- return
+ if cachedUris then
+ for uri in pairs(cachedUris) do
+ files.resetText(uri)
+ end
end
- for uri in pairs(cachedUris) do
- files.resetText(uri)
+ for uri in pairs(files.openMap) do
+ if scope.getScope(uri) == scp then
+ files.resetText(uri)
+ end
end
end
@@ -448,7 +503,7 @@ function m.awaitReload(scp)
scp:set('libraryMatcher', nil)
scp:removeAllLinks()
m.flushFiles(scp)
- plugin.init(scp)
+ m.onWatch('startReload', scp.uri)
m.awaitPreload(scp)
scp:set('ready', true)
local waiting = scp:get('waitingReady')
@@ -501,6 +556,7 @@ end
config.watch(function (uri, key, value, oldValue)
if key:find '^Lua.runtime'
or key:find '^Lua.workspace'
+ or key:find '^Lua.type'
or key:find '^files' then
if value ~= oldValue then
m.reload(scope.getScope(uri))