summaryrefslogtreecommitdiff
path: root/script
diff options
context:
space:
mode:
Diffstat (limited to 'script')
-rw-r--r--script/await.lua227
-rw-r--r--script/brave/brave.lua70
-rw-r--r--script/brave/init.lua4
-rw-r--r--script/brave/log.lua54
-rw-r--r--script/brave/work.lua60
-rw-r--r--script/config.lua218
-rw-r--r--script/core/code-action.lua269
-rw-r--r--script/core/command/removeSpace.lua56
-rw-r--r--script/core/command/solve.lua96
-rw-r--r--script/core/completion.lua1284
-rw-r--r--script/core/definition.lua156
-rw-r--r--script/core/diagnostics/ambiguity-1.lua69
-rw-r--r--script/core/diagnostics/circle-doc-class.lua54
-rw-r--r--script/core/diagnostics/code-after-break.lua34
-rw-r--r--script/core/diagnostics/doc-field-no-class.lua41
-rw-r--r--script/core/diagnostics/duplicate-doc-class.lua46
-rw-r--r--script/core/diagnostics/duplicate-doc-field.lua34
-rw-r--r--script/core/diagnostics/duplicate-doc-param.lua37
-rw-r--r--script/core/diagnostics/duplicate-index.lua63
-rw-r--r--script/core/diagnostics/empty-block.lua49
-rw-r--r--script/core/diagnostics/global-in-nil-env.lua66
-rw-r--r--script/core/diagnostics/init.lua56
-rw-r--r--script/core/diagnostics/lowercase-global.lua56
-rw-r--r--script/core/diagnostics/newfield-call.lua37
-rw-r--r--script/core/diagnostics/newline-call.lua38
-rw-r--r--script/core/diagnostics/redefined-local.lua32
-rw-r--r--script/core/diagnostics/redundant-parameter.lua82
-rw-r--r--script/core/diagnostics/redundant-value.lua24
-rw-r--r--script/core/diagnostics/trailing-space.lua55
-rw-r--r--script/core/diagnostics/undefined-doc-class.lua46
-rw-r--r--script/core/diagnostics/undefined-doc-name.lua60
-rw-r--r--script/core/diagnostics/undefined-doc-param.lua52
-rw-r--r--script/core/diagnostics/undefined-env-child.lua27
-rw-r--r--script/core/diagnostics/undefined-global.lua40
-rw-r--r--script/core/diagnostics/unused-function.lua40
-rw-r--r--script/core/diagnostics/unused-label.lua22
-rw-r--r--script/core/diagnostics/unused-local.lua93
-rw-r--r--script/core/diagnostics/unused-vararg.lua31
-rw-r--r--script/core/document-symbol.lua307
-rw-r--r--script/core/find-source.lua14
-rw-r--r--script/core/highlight.lua252
-rw-r--r--script/core/hover/arg.lua71
-rw-r--r--script/core/hover/description.lua204
-rw-r--r--script/core/hover/init.lua164
-rw-r--r--script/core/hover/label.lua211
-rw-r--r--script/core/hover/name.lua101
-rw-r--r--script/core/hover/return.lua125
-rw-r--r--script/core/hover/table.lua257
-rw-r--r--script/core/keyword.lua264
-rw-r--r--script/core/matchkey.lua33
-rw-r--r--script/core/reference.lua116
-rw-r--r--script/core/rename.lua448
-rw-r--r--script/core/semantic-tokens.lua161
-rw-r--r--script/core/signature.lua106
-rw-r--r--script/core/workspace-symbol.lua69
-rw-r--r--script/doctor.lua380
-rw-r--r--script/file-uri.lua89
-rw-r--r--script/files.lua438
-rw-r--r--script/fs-utility.lua559
-rw-r--r--script/glob/gitignore.lua221
-rw-r--r--script/glob/glob.lua122
-rw-r--r--script/glob/init.lua4
-rw-r--r--script/glob/matcher.lua151
-rw-r--r--script/json-beautify.lua120
-rw-r--r--script/json.lua450
-rw-r--r--script/jsonrpc.lua64
-rw-r--r--script/language.lua140
-rw-r--r--script/library.lua205
-rw-r--r--script/log.lua142
-rw-r--r--script/parser/ast.lua1751
-rw-r--r--script/parser/calcline.lua94
-rw-r--r--script/parser/compile.lua561
-rw-r--r--script/parser/grammar.lua538
-rw-r--r--script/parser/guide.lua3884
-rw-r--r--script/parser/init.lua12
-rw-r--r--script/parser/lines.lua45
-rw-r--r--script/parser/luadoc.lua991
-rw-r--r--script/parser/parse.lua49
-rw-r--r--script/parser/relabel.lua361
-rw-r--r--script/parser/split.lua9
-rw-r--r--script/proto/define.lua287
-rw-r--r--script/proto/init.lua3
-rw-r--r--script/proto/proto.lua147
-rw-r--r--script/provider/capability.lua61
-rw-r--r--script/provider/client.lua18
-rw-r--r--script/provider/completion.lua54
-rw-r--r--script/provider/diagnostic.lua303
-rw-r--r--script/provider/init.lua1
-rw-r--r--script/provider/markdown.lua26
-rw-r--r--script/provider/provider.lua642
-rw-r--r--script/provider/semantic-tokens.lua64
-rw-r--r--script/pub/init.lua4
-rw-r--r--script/pub/pub.lua242
-rw-r--r--script/pub/report.lua26
-rw-r--r--script/service/init.lua3
-rw-r--r--script/service/service.lua158
-rw-r--r--script/timer.lua218
-rw-r--r--script/utility.lua559
-rw-r--r--script/vm/eachDef.lua40
-rw-r--r--script/vm/eachField.lua45
-rw-r--r--script/vm/eachRef.lua39
-rw-r--r--script/vm/getClass.lua62
-rw-r--r--script/vm/getDocs.lua175
-rw-r--r--script/vm/getGlobals.lua192
-rw-r--r--script/vm/getInfer.lua96
-rw-r--r--script/vm/getLibrary.lua32
-rw-r--r--script/vm/getLinks.lua61
-rw-r--r--script/vm/getMeta.lua52
-rw-r--r--script/vm/guideInterface.lua106
-rw-r--r--script/vm/init.lua13
-rw-r--r--script/vm/vm.lua167
-rw-r--r--script/without-check-nil.lua126
-rw-r--r--script/workspace/init.lua3
-rw-r--r--script/workspace/require-path.lua74
-rw-r--r--script/workspace/workspace.lua330
115 files changed, 22490 insertions, 0 deletions
diff --git a/script/await.lua b/script/await.lua
new file mode 100644
index 00000000..d8e2a9ad
--- /dev/null
+++ b/script/await.lua
@@ -0,0 +1,227 @@
+local timer = require 'timer'
+local util = require 'utility'
+
+---@class await
+local m = {}
+m.type = 'await'
+
+m.coMap = setmetatable({}, { __mode = 'k' })
+m.idMap = {}
+m.delayQueue = {}
+m.delayQueueIndex = 1
+m.watchList = {}
+m._enable = true
+
+--- 设置错误处理器
+---@param errHandle function {comment = '当有错误发生时,会以错误堆栈为参数调用该函数'}
+function m.setErrorHandle(errHandle)
+ m.errorHandle = errHandle
+end
+
+function m.checkResult(co, ...)
+ local suc, err = ...
+ if not suc and m.errorHandle then
+ m.errorHandle(debug.traceback(co, err))
+ end
+ return ...
+end
+
+--- 创建一个任务
+function m.call(callback, ...)
+ local co = coroutine.create(callback)
+ local closers = {}
+ m.coMap[co] = {
+ closers = closers,
+ priority = false,
+ }
+ for i = 1, select('#', ...) do
+ local id = select(i, ...)
+ if not id then
+ break
+ end
+ m.setID(id, co)
+ end
+
+ local currentCo = coroutine.running()
+ local current = m.coMap[currentCo]
+ if current then
+ for closer in pairs(current.closers) do
+ closers[closer] = true
+ closer(co)
+ end
+ end
+ return m.checkResult(co, coroutine.resume(co))
+end
+
+--- 创建一个任务,并挂起当前线程,当任务完成后再延续当前线程/若任务被关闭,则返回nil
+function m.await(callback, ...)
+ if not coroutine.isyieldable() then
+ return callback(...)
+ end
+ return m.wait(function (waker, ...)
+ m.call(function ()
+ local returnNil <close> = util.defer(waker)
+ waker(callback())
+ end, ...)
+ end, ...)
+end
+
+--- 设置一个id,用于批量关闭任务
+function m.setID(id, co)
+ co = co or coroutine.running()
+ if not m.idMap[id] then
+ m.idMap[id] = setmetatable({}, { __mode = 'k' })
+ end
+ m.idMap[id][co] = true
+end
+
+--- 根据id批量关闭任务
+function m.close(id)
+ local map = m.idMap[id]
+ if not map then
+ return
+ end
+ local count = 0
+ for co in pairs(map) do
+ map[co] = nil
+ coroutine.close(co)
+ count = count + 1
+ end
+ log.debug('Close await:', id, count)
+end
+
+function m.hasID(id, co)
+ co = co or coroutine.running()
+ return m.idMap[id] and m.idMap[id][co] ~= nil
+end
+
+--- 休眠一段时间
+---@param time number
+function m.sleep(time)
+ if not coroutine.isyieldable() then
+ if m.errorHandle then
+ m.errorHandle(debug.traceback('Cannot yield'))
+ end
+ return
+ end
+ local co = coroutine.running()
+ timer.wait(time, function ()
+ if coroutine.status(co) ~= 'suspended' then
+ return
+ end
+ return m.checkResult(co, coroutine.resume(co))
+ end)
+ return coroutine.yield()
+end
+
+--- 等待直到唤醒
+---@param callback function
+function m.wait(callback, ...)
+ if not coroutine.isyieldable() then
+ return
+ end
+ local co = coroutine.running()
+ local waked
+ callback(function (...)
+ if waked then
+ return
+ end
+ waked = true
+ if coroutine.status(co) ~= 'suspended' then
+ return
+ end
+ return m.checkResult(co, coroutine.resume(co, ...))
+ end, ...)
+ return coroutine.yield()
+end
+
+--- 延迟
+function m.delay()
+ if not m._enable then
+ return
+ end
+ if not coroutine.isyieldable() then
+ return
+ end
+ local co = coroutine.running()
+ local current = m.coMap[co]
+ if m.onWatch('delay', co) == false then
+ return
+ end
+ -- TODO
+ if current.priority then
+ return
+ end
+ m.delayQueue[#m.delayQueue+1] = function ()
+ if coroutine.status(co) ~= 'suspended' then
+ return
+ end
+ return m.checkResult(co, coroutine.resume(co))
+ end
+ return coroutine.yield()
+end
+
+local function warnStepTime(passed, waker)
+ if passed < 1 then
+ log.warn(('Await step takes [%.3f] sec.'):format(passed))
+ return
+ end
+ for i = 1, 100 do
+ local name, v = debug.getupvalue(waker, i)
+ if not name then
+ return
+ end
+ if name == 'co' then
+ log.warn(debug.traceback(v, ('[fire]Await step takes [%.3f] sec.'):format(passed)))
+ return
+ end
+ end
+end
+
+--- 步进
+function m.step()
+ local waker = m.delayQueue[m.delayQueueIndex]
+ if waker then
+ m.delayQueue[m.delayQueueIndex] = false
+ m.delayQueueIndex = m.delayQueueIndex + 1
+ local clock = os.clock()
+ waker()
+ local passed = os.clock() - clock
+ if passed > 0.1 then
+ warnStepTime(passed, waker)
+ end
+ return true
+ else
+ m.delayQueue = {}
+ m.delayQueueIndex = 1
+ return false
+ end
+end
+
+function m.setPriority(n)
+ m.coMap[coroutine.running()].priority = true
+end
+
+function m.enable()
+ m._enable = true
+end
+
+function m.disable()
+ m._enable = false
+end
+
+--- 注册事件
+function m.watch(callback)
+ m.watchList[#m.watchList+1] = callback
+end
+
+function m.onWatch(ev, ...)
+ for _, callback in ipairs(m.watchList) do
+ local res = callback(ev, ...)
+ if res ~= nil then
+ return res
+ end
+ end
+end
+
+return m
diff --git a/script/brave/brave.lua b/script/brave/brave.lua
new file mode 100644
index 00000000..08909074
--- /dev/null
+++ b/script/brave/brave.lua
@@ -0,0 +1,70 @@
+local thread = require 'bee.thread'
+
+---@class pub_brave
+local m = {}
+m.type = 'brave'
+m.ability = {}
+m.queue = {}
+
+--- 注册成为勇者
+function m.register(id)
+ m.taskpad = thread.channel('taskpad' .. id)
+ m.waiter = thread.channel('waiter' .. id)
+ m.id = id
+
+ if #m.queue > 0 then
+ for _, info in ipairs(m.queue) do
+ m.waiter:push(info.name, info.params)
+ end
+ end
+ m.queue = nil
+
+ m.start()
+end
+
+--- 注册能力
+function m.on(name, callback)
+ m.ability[name] = callback
+end
+
+--- 报告
+function m.push(name, params)
+ if m.waiter then
+ m.waiter:push(name, params)
+ else
+ m.queue[#m.queue+1] = {
+ name = name,
+ params = params,
+ }
+ end
+end
+
+--- 开始找工作
+function m.start()
+ m.push('mem', collectgarbage 'count')
+ while true do
+ local suc, name, id, params = m.taskpad:pop()
+ if not suc then
+ -- 找不到工作的勇者,只好睡觉
+ thread.sleep(0.001)
+ goto CONTINUE
+ end
+ local ability = m.ability[name]
+ -- TODO
+ if not ability then
+ m.waiter:push(id)
+ log.error('Brave can not handle this work: ' .. name)
+ goto CONTINUE
+ end
+ local ok, res = xpcall(ability, log.error, params)
+ if ok then
+ m.waiter:push(id, res)
+ else
+ m.waiter:push(id)
+ end
+ m.push('mem', collectgarbage 'count')
+ ::CONTINUE::
+ end
+end
+
+return m
diff --git a/script/brave/init.lua b/script/brave/init.lua
new file mode 100644
index 00000000..24c2e412
--- /dev/null
+++ b/script/brave/init.lua
@@ -0,0 +1,4 @@
+local brave = require 'brave.brave'
+require 'brave.work'
+
+return brave
diff --git a/script/brave/log.lua b/script/brave/log.lua
new file mode 100644
index 00000000..18ea7853
--- /dev/null
+++ b/script/brave/log.lua
@@ -0,0 +1,54 @@
+local brave = require 'brave'
+
+local tablePack = table.pack
+local tostring = tostring
+local tableConcat = table.concat
+local debugTraceBack = debug.traceback
+local debugGetInfo = debug.getinfo
+local osClock = os.clock
+
+_ENV = nil
+
+local function pushLog(level, ...)
+ local t = tablePack(...)
+ for i = 1, t.n do
+ t[i] = tostring(t[i])
+ end
+ local str = tableConcat(t, '\t', 1, t.n)
+ if level == 'error' then
+ str = str .. '\n' .. debugTraceBack(nil, 3)
+ end
+ local info = debugGetInfo(3, 'Sl')
+ brave.push('log', {
+ level = level,
+ msg = str,
+ src = info.source,
+ line = info.currentline,
+ clock = osClock(),
+ })
+ return str
+end
+
+local m = {}
+
+function m.info(...)
+ pushLog('info', ...)
+end
+
+function m.debug(...)
+ pushLog('debug', ...)
+end
+
+function m.trace(...)
+ pushLog('trace', ...)
+end
+
+function m.warn(...)
+ pushLog('warn', ...)
+end
+
+function m.error(...)
+ pushLog('error', ...)
+end
+
+return m
diff --git a/script/brave/work.lua b/script/brave/work.lua
new file mode 100644
index 00000000..5ec8178f
--- /dev/null
+++ b/script/brave/work.lua
@@ -0,0 +1,60 @@
+local brave = require 'brave.brave'
+local parser = require 'parser'
+local fs = require 'bee.filesystem'
+local furi = require 'file-uri'
+local util = require 'utility'
+local thread = require 'bee.thread'
+
+brave.on('loadProto', function ()
+ local jsonrpc = require 'jsonrpc'
+ while true do
+ local proto, err = jsonrpc.decode(io.read, log.error)
+ --log.debug('loaded proto', proto.method)
+ if not proto then
+ brave.push('protoerror', err)
+ return
+ end
+ brave.push('proto', proto)
+ thread.sleep(0.001)
+ end
+end)
+
+brave.on('compile', function (text)
+ local state, err = parser:compile(text, 'lua', 'Lua 5.4')
+ if not state then
+ log.error(err)
+ return
+ end
+ local lines = parser:lines(text)
+ return {
+ root = state.root,
+ value = state.value,
+ errs = state.errs,
+ lines = lines,
+ }
+end)
+
+brave.on('listDirectory', function (uri)
+ local path = fs.path(furi.decode(uri))
+ local uris = {}
+ for child in path:list_directory() do
+ local childUri = furi.encode(child:string())
+ uris[#uris+1] = childUri
+ end
+ return uris
+end)
+
+brave.on('isDirectory', function (uri)
+ local path = fs.path(furi.decode(uri))
+ return fs.is_directory(path)
+end)
+
+brave.on('loadFile', function (uri)
+ local filename = furi.decode(uri)
+ return util.loadFile(filename)
+end)
+
+brave.on('saveFile', function (params)
+ local filename = furi.decode(params.uri)
+ return util.saveFile(filename, params.text)
+end)
diff --git a/script/config.lua b/script/config.lua
new file mode 100644
index 00000000..0544c317
--- /dev/null
+++ b/script/config.lua
@@ -0,0 +1,218 @@
+local util = require 'utility'
+local define = require 'proto.define'
+
+local m = {}
+m.version = 0
+
+local function Boolean(v)
+ if type(v) == 'boolean' then
+ return true, v
+ end
+ return false
+end
+
+local function Integer(v)
+ if type(v) == 'number' then
+ return true, math.floor(v)
+ end
+ return false
+end
+
+local function String(v)
+ return true, tostring(v)
+end
+
+local function Str2Hash(sep)
+ return function (v)
+ if type(v) == 'string' then
+ local t = {}
+ for s in v:gmatch('[^'..sep..']+') do
+ t[s] = true
+ end
+ return true, t
+ end
+ if type(v) == 'table' then
+ local t = {}
+ for _, s in ipairs(v) do
+ if type(s) == 'string' then
+ t[s] = true
+ end
+ end
+ return true, t
+ end
+ return false
+ end
+end
+
+local function Array(checker)
+ return function (tbl)
+ if type(tbl) ~= 'table' then
+ return false
+ end
+ local t = {}
+ for _, v in ipairs(tbl) do
+ local ok, result = checker(v)
+ if ok then
+ t[#t+1] = result
+ end
+ end
+ return true, t
+ end
+end
+
+local function Hash(keyChecker, valueChecker)
+ return function (tbl)
+ if type(tbl) ~= 'table' then
+ return false
+ end
+ local t = {}
+ for k, v in pairs(tbl) do
+ local ok1, key = keyChecker(k)
+ local ok2, value = valueChecker(v)
+ if ok1 and ok2 then
+ t[key] = value
+ end
+ end
+ if not next(t) then
+ return false
+ end
+ return true, t
+ end
+end
+
+local function Or(...)
+ local checkers = {...}
+ return function (obj)
+ for _, checker in ipairs(checkers) do
+ local suc, res = checker(obj)
+ if suc then
+ return true, res
+ end
+ end
+ return false
+ end
+end
+
+local ConfigTemplate = {
+ runtime = {
+ version = {'Lua 5.4', String},
+ library = {{}, Str2Hash ';'},
+ path = {{
+ "?.lua",
+ "?/init.lua",
+ "?/?.lua"
+ }, Array(String)},
+ special = {{}, Hash(String, String)},
+ meta = {'${version} ${language}', String},
+ },
+ diagnostics = {
+ enable = {true, Boolean},
+ globals = {{}, Str2Hash ';'},
+ disable = {{}, Str2Hash ';'},
+ severity = {
+ util.deepCopy(define.DiagnosticDefaultSeverity),
+ Hash(String, String),
+ },
+ workspaceDelay = {0, Integer},
+ workspaceRate = {100, Integer},
+ },
+ workspace = {
+ ignoreDir = {{}, Str2Hash ';'},
+ ignoreSubmodules= {true, Boolean},
+ useGitIgnore = {true, Boolean},
+ maxPreload = {1000, Integer},
+ preloadFileSize = {100, Integer},
+ library = {{}, Hash(
+ String,
+ Or(Boolean, Array(String))
+ )}
+ },
+ completion = {
+ enable = {true, Boolean},
+ callSnippet = {'Disable', String},
+ keywordSnippet = {'Replace', String},
+ displayContext = {6, Integer},
+ },
+ signatureHelp = {
+ enable = {true, Boolean},
+ },
+ hover = {
+ enable = {true, Boolean},
+ viewString = {true, Boolean},
+ viewStringMax = {1000, Integer},
+ viewNumber = {true, Boolean},
+ fieldInfer = {3000, Integer},
+ },
+ color = {
+ mode = {'Semantic', String},
+ },
+ luadoc = {
+ enable = {true, Boolean},
+ },
+ plugin = {
+ enable = {false, Boolean},
+ path = {'.vscode/lua-plugin/*.lua', String},
+ },
+ intelliSense = {
+ searchDepth = {0, Integer},
+ fastGlobal = {true, Boolean},
+ },
+}
+
+local OtherTemplate = {
+ associations = {{}, Hash(String, String)},
+ exclude = {{}, Hash(String, Boolean)},
+}
+
+local function init()
+ if m.config then
+ return
+ end
+
+ m.config = {}
+ for c, t in pairs(ConfigTemplate) do
+ m.config[c] = {}
+ for k, info in pairs(t) do
+ m.config[c][k] = info[1]
+ end
+ end
+
+ m.other = {}
+ for k, info in pairs(OtherTemplate) do
+ m.other[k] = info[1]
+ end
+end
+
+function m.setConfig(config, other)
+ m.version = m.version + 1
+ xpcall(function ()
+ for c, t in pairs(config) do
+ for k, v in pairs(t) do
+ local region = ConfigTemplate[c]
+ if region then
+ local info = region[k]
+ local suc, v = info[2](v)
+ if suc then
+ m.config[c][k] = v
+ else
+ m.config[c][k] = info[1]
+ end
+ end
+ end
+ end
+ for k, v in pairs(other) do
+ local info = OtherTemplate[k]
+ local suc, v = info[2](v)
+ if suc then
+ m.other[k] = v
+ else
+ m.other[k] = info[1]
+ end
+ end
+ log.debug('Config update: ', util.dump(m.config), util.dump(m.other))
+ end, log.error)
+end
+
+init()
+
+return m
diff --git a/script/core/code-action.lua b/script/core/code-action.lua
new file mode 100644
index 00000000..69304f98
--- /dev/null
+++ b/script/core/code-action.lua
@@ -0,0 +1,269 @@
+local files = require 'files'
+local lang = require 'language'
+local define = require 'proto.define'
+local guide = require 'parser.guide'
+local util = require 'utility'
+local sp = require 'bee.subprocess'
+
+local function disableDiagnostic(uri, code, results)
+ results[#results+1] = {
+ title = lang.script('ACTION_DISABLE_DIAG', code),
+ kind = 'quickfix',
+ command = {
+ title = lang.script.COMMAND_DISABLE_DIAG,
+ command = 'lua.config',
+ arguments = {
+ {
+ key = 'Lua.diagnostics.disable',
+ action = 'add',
+ value = code,
+ uri = uri,
+ }
+ }
+ }
+ }
+end
+
+local function markGlobal(uri, name, results)
+ results[#results+1] = {
+ title = lang.script('ACTION_MARK_GLOBAL', name),
+ kind = 'quickfix',
+ command = {
+ title = lang.script.COMMAND_MARK_GLOBAL,
+ command = 'lua.config',
+ arguments = {
+ {
+ key = 'Lua.diagnostics.globals',
+ action = 'add',
+ value = name,
+ uri = uri,
+ }
+ }
+ }
+ }
+end
+
+local function changeVersion(uri, version, results)
+ results[#results+1] = {
+ title = lang.script('ACTION_RUNTIME_VERSION', version),
+ kind = 'quickfix',
+ command = {
+ title = lang.script.COMMAND_RUNTIME_VERSION,
+ command = 'lua.config',
+ arguments = {
+ {
+ key = 'Lua.runtime.version',
+ action = 'set',
+ value = version,
+ uri = uri,
+ }
+ }
+ },
+ }
+end
+
+local function solveUndefinedGlobal(uri, diag, results)
+ local ast = files.getAst(uri)
+ local text = files.getText(uri)
+ local lines = files.getLines(uri)
+ local offset = define.offsetOfWord(lines, text, diag.range.start)
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ if source.type ~= 'getglobal' then
+ return
+ end
+
+ local name = guide.getName(source)
+ markGlobal(uri, name, results)
+
+ -- TODO check other version
+ end)
+end
+
+local function solveLowercaseGlobal(uri, diag, results)
+ local ast = files.getAst(uri)
+ local text = files.getText(uri)
+ local lines = files.getLines(uri)
+ local offset = define.offsetOfWord(lines, text, diag.range.start)
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ if source.type ~= 'setglobal' then
+ return
+ end
+
+ local name = guide.getName(source)
+ markGlobal(uri, name, results)
+ end)
+end
+
+local function findSyntax(uri, diag)
+ local ast = files.getAst(uri)
+ local text = files.getText(uri)
+ local lines = files.getLines(uri)
+ for _, err in ipairs(ast.errs) do
+ if err.type:lower():gsub('_', '-') == diag.code then
+ local range = define.range(lines, text, err.start, err.finish)
+ if util.equal(range, diag.range) then
+ return err
+ end
+ end
+ end
+ return nil
+end
+
+local function solveSyntaxByChangeVersion(uri, err, results)
+ if type(err.version) == 'table' then
+ for _, version in ipairs(err.version) do
+ changeVersion(uri, version, results)
+ end
+ else
+ changeVersion(uri, err.version, results)
+ end
+end
+
+local function solveSyntaxByAddDoEnd(uri, err, results)
+ local text = files.getText(uri)
+ local lines = files.getLines(uri)
+ results[#results+1] = {
+ title = lang.script.ACTION_ADD_DO_END,
+ kind = 'quickfix',
+ edit = {
+ changes = {
+ [uri] = {
+ {
+ range = define.range(lines, text, err.start, err.finish),
+ newText = ('do %s end'):format(text:sub(err.start, err.finish)),
+ },
+ }
+ }
+ }
+ }
+end
+
+local function solveSyntaxByFix(uri, err, results)
+ local text = files.getText(uri)
+ local lines = files.getLines(uri)
+ local changes = {}
+ for _, fix in ipairs(err.fix) do
+ changes[#changes+1] = {
+ range = define.range(lines, text, fix.start, fix.finish),
+ newText = fix.text,
+ }
+ end
+ results[#results+1] = {
+ title = lang.script['ACTION_' .. err.fix.title],
+ kind = 'quickfix',
+ edit = {
+ changes = {
+ [uri] = changes,
+ }
+ }
+ }
+end
+
+local function solveSyntax(uri, diag, results)
+ local err = findSyntax(uri, diag)
+ if not err then
+ return
+ end
+ if err.version then
+ solveSyntaxByChangeVersion(uri, err, results)
+ end
+ if err.type == 'ACTION_AFTER_BREAK' or err.type == 'ACTION_AFTER_RETURN' then
+ solveSyntaxByAddDoEnd(uri, err, results)
+ end
+ if err.fix then
+ solveSyntaxByFix(uri, err, results)
+ end
+end
+
+local function solveNewlineCall(uri, diag, results)
+ local text = files.getText(uri)
+ local lines = files.getLines(uri)
+ results[#results+1] = {
+ title = lang.script.ACTION_ADD_SEMICOLON,
+ kind = 'quickfix',
+ edit = {
+ changes = {
+ [uri] = {
+ {
+ range = {
+ start = diag.range.start,
+ ['end'] = diag.range.start,
+ },
+ newText = ';',
+ }
+ }
+ }
+ }
+ }
+end
+
+local function solveAmbiguity1(uri, diag, results)
+ results[#results+1] = {
+ title = lang.script.ACTION_ADD_BRACKETS,
+ kind = 'quickfix',
+ command = {
+ title = lang.script.COMMAND_ADD_BRACKETS,
+ command = 'lua.solve:' .. sp:get_id(),
+ arguments = {
+ {
+ name = 'ambiguity-1',
+ uri = uri,
+ range = diag.range,
+ }
+ }
+ },
+ }
+end
+
+local function solveTrailingSpace(uri, diag, results)
+ results[#results+1] = {
+ title = lang.script.ACTION_REMOVE_SPACE,
+ kind = 'quickfix',
+ command = {
+ title = lang.script.COMMAND_REMOVE_SPACE,
+ command = 'lua.removeSpace:' .. sp:get_id(),
+ arguments = {
+ {
+ uri = uri,
+ }
+ }
+ },
+ }
+end
+
+local function solveDiagnostic(uri, diag, results)
+ if diag.source == lang.script.DIAG_SYNTAX_CHECK then
+ solveSyntax(uri, diag, results)
+ return
+ end
+ if not diag.code then
+ return
+ end
+ if diag.code == 'undefined-global' then
+ solveUndefinedGlobal(uri, diag, results)
+ elseif diag.code == 'lowercase-global' then
+ solveLowercaseGlobal(uri, diag, results)
+ elseif diag.code == 'newline-call' then
+ solveNewlineCall(uri, diag, results)
+ elseif diag.code == 'ambiguity-1' then
+ solveAmbiguity1(uri, diag, results)
+ elseif diag.code == 'trailing-space' then
+ solveTrailingSpace(uri, diag, results)
+ end
+ disableDiagnostic(uri, diag.code, results)
+end
+
+return function (uri, range, diagnostics)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+
+ local results = {}
+
+ for _, diag in ipairs(diagnostics) do
+ solveDiagnostic(uri, diag, results)
+ end
+
+ return results
+end
diff --git a/script/core/command/removeSpace.lua b/script/core/command/removeSpace.lua
new file mode 100644
index 00000000..e8b09932
--- /dev/null
+++ b/script/core/command/removeSpace.lua
@@ -0,0 +1,56 @@
+local files = require 'files'
+local define = require 'proto.define'
+local guide = require 'parser.guide'
+local proto = require 'proto'
+local lang = require 'language'
+
+local function isInString(ast, offset)
+ return guide.eachSourceContain(ast.ast, offset, function (source)
+ if source.type == 'string' then
+ return true
+ end
+ end) or false
+end
+
+return function (data)
+ local uri = data.uri
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+ local ast = files.getAst(uri)
+ if not lines then
+ return
+ end
+
+ local textEdit = {}
+ for i = 1, #lines do
+ local line = guide.lineContent(lines, text, i, true)
+ local pos = line:find '[ \t]+$'
+ if pos then
+ local start, finish = guide.lineRange(lines, i, true)
+ start = start + pos - 1
+ if isInString(ast, start) then
+ goto NEXT_LINE
+ end
+ textEdit[#textEdit+1] = {
+ range = define.range(lines, text, start, finish),
+ newText = '',
+ }
+ goto NEXT_LINE
+ end
+
+ ::NEXT_LINE::
+ end
+
+ if #textEdit == 0 then
+ return
+ end
+
+ proto.awaitRequest('workspace/applyEdit', {
+ label = lang.script.COMMAND_REMOVE_SPACE,
+ edit = {
+ changes = {
+ [uri] = textEdit,
+ }
+ },
+ })
+end
diff --git a/script/core/command/solve.lua b/script/core/command/solve.lua
new file mode 100644
index 00000000..d3b8f94e
--- /dev/null
+++ b/script/core/command/solve.lua
@@ -0,0 +1,96 @@
+local files = require 'files'
+local define = require 'proto.define'
+local guide = require 'parser.guide'
+local proto = require 'proto'
+local lang = require 'language'
+
+local opMap = {
+ ['+'] = true,
+ ['-'] = true,
+ ['*'] = true,
+ ['/'] = true,
+ ['//'] = true,
+ ['^'] = true,
+ ['<<'] = true,
+ ['>>'] = true,
+ ['&'] = true,
+ ['|'] = true,
+ ['~'] = true,
+ ['..'] = true,
+}
+
+local literalMap = {
+ ['number'] = true,
+ ['boolean'] = true,
+ ['string'] = true,
+ ['table'] = true,
+}
+
+return function (data)
+ local uri = data.uri
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ local start = define.offsetOfWord(lines, text, data.range.start)
+ local finish = define.offsetOfWord(lines, text, data.range['end'])
+
+ local result = guide.eachSourceContain(ast.ast, start, function (source)
+ if source.start ~= start
+ or source.finish ~= finish then
+ return
+ end
+ if not source.op or source.op.type ~= 'or' then
+ return
+ end
+ local first = source[1]
+ local second = source[2]
+ -- a + b or 0 --> a + (b or 0)
+ do
+ if first.op
+ and opMap[first.op.type]
+ and first.type ~= 'unary'
+ and not second.op
+ and literalMap[second.type] then
+ return {
+ start = source[1][2].start,
+ finish = source[2].finish,
+ }
+ end
+ end
+ -- a or b + c --> (a or b) + c
+ do
+ if second.op
+ and opMap[second.op.type]
+ and second.type ~= 'unary'
+ and not first.op
+ and literalMap[second[1].type] then
+ return {
+ start = source[1].start,
+ finish = source[2][1].finish,
+ }
+ end
+ end
+ end)
+
+ if not result then
+ return
+ end
+
+ proto.awaitRequest('workspace/applyEdit', {
+ label = lang.script.COMMAND_REMOVE_SPACE,
+ edit = {
+ changes = {
+ [uri] = {
+ {
+ range = define.range(lines, text, result.start, result.finish),
+ newText = ('(%s)'):format(text:sub(result.start, result.finish)),
+ }
+ },
+ }
+ },
+ })
+end
diff --git a/script/core/completion.lua b/script/core/completion.lua
new file mode 100644
index 00000000..44874b39
--- /dev/null
+++ b/script/core/completion.lua
@@ -0,0 +1,1284 @@
+local define = require 'proto.define'
+local files = require 'files'
+local guide = require 'parser.guide'
+local matchKey = require 'core.matchkey'
+local vm = require 'vm'
+local getLabel = require 'core.hover.label'
+local getName = require 'core.hover.name'
+local getArg = require 'core.hover.arg'
+local getDesc = require 'core.hover.description'
+local getHover = require 'core.hover'
+local config = require 'config'
+local util = require 'utility'
+local markdown = require 'provider.markdown'
+local findSource = require 'core.find-source'
+local await = require 'await'
+local parser = require 'parser'
+local keyWordMap = require 'core.keyword'
+local workspace = require 'workspace'
+local furi = require 'file-uri'
+local rpath = require 'workspace.require-path'
+local lang = require 'language'
+
+local stackID = 0
+local stacks = {}
+local function stack(callback)
+ stackID = stackID + 1
+ stacks[stackID] = callback
+ return stackID
+end
+
+local function clearStack()
+ stacks = {}
+end
+
+local function resolveStack(id)
+ local callback = stacks[id]
+ if not callback then
+ return nil
+ end
+
+ -- 当进行新的 resolve 时,放弃当前的 resolve
+ await.close('completion.resove')
+ return await.await(callback, 'completion.resove')
+end
+
+local function trim(str)
+ return str:match '^%s*(%S+)%s*$'
+end
+
+local function isSpace(char)
+ if char == ' '
+ or char == '\n'
+ or char == '\r'
+ or char == '\t' then
+ return true
+ end
+ return false
+end
+
+local function skipSpace(text, offset)
+ for i = offset, 1, -1 do
+ local char = text:sub(i, i)
+ if not isSpace(char) then
+ return i
+ end
+ end
+ return 0
+end
+
+local function findWord(text, offset)
+ for i = offset, 1, -1 do
+ if not text:sub(i, i):match '[%w_]' then
+ if i == offset then
+ return nil
+ end
+ return text:sub(i+1, offset), i+1
+ end
+ end
+ return text:sub(1, offset), 1
+end
+
+local function findSymbol(text, offset)
+ for i = offset, 1, -1 do
+ local char = text:sub(i, i)
+ if isSpace(char) then
+ goto CONTINUE
+ end
+ if char == '.'
+ or char == ':'
+ or char == '(' then
+ return char, i
+ else
+ return nil
+ end
+ ::CONTINUE::
+ end
+ return nil
+end
+
+local function findAnyPos(text, offset)
+ for i = offset, 1, -1 do
+ if not isSpace(text:sub(i, i)) then
+ return i
+ end
+ end
+ return nil
+end
+
+local function findParent(ast, text, offset)
+ for i = offset, 1, -1 do
+ local char = text:sub(i, i)
+ if isSpace(char) then
+ goto CONTINUE
+ end
+ local oop
+ if char == '.' then
+ -- `..` 的情况
+ if text:sub(i-1, i-1) == '.' then
+ return nil, nil
+ end
+ oop = false
+ elseif char == ':' then
+ oop = true
+ else
+ return nil, nil
+ end
+ local anyPos = findAnyPos(text, i-1)
+ if not anyPos then
+ return nil, nil
+ end
+ local parent = guide.eachSourceContain(ast.ast, anyPos, function (source)
+ if source.finish == anyPos then
+ return source
+ end
+ end)
+ if parent then
+ return parent, oop
+ end
+ ::CONTINUE::
+ end
+ return nil, nil
+end
+
+local function findParentInStringIndex(ast, text, offset)
+ local near, nearStart
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ local start = guide.getStartFinish(source)
+ if not start then
+ return
+ end
+ if not nearStart or nearStart < start then
+ near = source
+ nearStart = start
+ end
+ end)
+ if not near or near.type ~= 'string' then
+ return
+ end
+ local parent = near.parent
+ if not parent or parent.index ~= near then
+ return
+ end
+ -- index不可能是oop模式
+ return parent.node, false
+end
+
+local function buildFunctionSnip(source, oop)
+ local name = getName(source):gsub('^.-[$.:]', '')
+ local defs = vm.getDefs(source, 'deep')
+ local args = ''
+ for _, def in ipairs(defs) do
+ local defArgs = getArg(def, oop)
+ if defArgs ~= '' then
+ args = defArgs
+ break
+ end
+ end
+ local id = 0
+ args = args:gsub('[^,]+', function (arg)
+ id = id + 1
+ return arg:gsub('^(%s*)(.+)', function (sp, word)
+ return ('%s${%d:%s}'):format(sp, id, word)
+ end)
+ end)
+ return ('%s(%s)'):format(name, args)
+end
+
+local function buildDetail(source)
+ local types = vm.getInferType(source, 'deep')
+ local literals = vm.getInferLiteral(source, 'deep')
+ if literals then
+ return types .. ' = ' .. literals
+ else
+ return types
+ end
+end
+
+local function getSnip(source)
+ local context = config.config.completion.displayContext
+ if context <= 0 then
+ return nil
+ end
+ local defs = vm.getRefs(source, 'deep')
+ for _, def in ipairs(defs) do
+ def = guide.getObjectValue(def) or def
+ if def ~= source and def.type == 'function' then
+ local uri = guide.getUri(def)
+ local text = files.getText(uri)
+ local lines = files.getLines(uri)
+ if not text then
+ goto CONTINUE
+ end
+ if vm.isMetaFile(uri) then
+ goto CONTINUE
+ end
+ local row = guide.positionOf(lines, def.start)
+ local firstRow = lines[row]
+ local lastRow = lines[math.min(row + context - 1, #lines)]
+ local snip = text:sub(firstRow.start, lastRow.finish)
+ return snip
+ end
+ ::CONTINUE::
+ end
+end
+
+local function buildDesc(source)
+ local hover = getHover.get(source)
+ local md = markdown()
+ md:add('lua', hover.label)
+ md:add('md', hover.description)
+ local snip = getSnip(source)
+ if snip then
+ md:add('md', '-------------')
+ md:add('lua', snip)
+ end
+ return md:string()
+end
+
+local function buildFunction(results, source, oop, data)
+ local snipType = config.config.completion.callSnippet
+ if snipType == 'Disable' or snipType == 'Both' then
+ results[#results+1] = data
+ end
+ if snipType == 'Both' or snipType == 'Replace' then
+ local snipData = util.deepCopy(data)
+ snipData.kind = define.CompletionItemKind.Snippet
+ snipData.label = snipData.label .. '()'
+ snipData.insertText = buildFunctionSnip(source, oop)
+ snipData.insertTextFormat = 2
+ snipData.id = stack(function ()
+ return {
+ detail = buildDetail(source),
+ description = buildDesc(source),
+ }
+ end)
+ results[#results+1] = snipData
+ end
+end
+
+local function isSameSource(ast, source, pos)
+ if not files.eq(guide.getUri(source), guide.getUri(ast.ast)) then
+ return false
+ end
+ if source.type == 'field'
+ or source.type == 'method' then
+ source = source.parent
+ end
+ return source.start <= pos and source.finish >= pos
+end
+
+local function checkLocal(ast, word, offset, results)
+ local locals = guide.getVisibleLocals(ast.ast, offset)
+ for name, source in pairs(locals) do
+ if isSameSource(ast, source, offset) then
+ goto CONTINUE
+ end
+ if not matchKey(word, name) then
+ goto CONTINUE
+ end
+ if vm.hasType(source, 'function') then
+ buildFunction(results, source, false, {
+ label = name,
+ kind = define.CompletionItemKind.Function,
+ id = stack(function ()
+ return {
+ detail = buildDetail(source),
+ description = buildDesc(source),
+ }
+ end),
+ })
+ else
+ results[#results+1] = {
+ label = name,
+ kind = define.CompletionItemKind.Variable,
+ id = stack(function ()
+ return {
+ detail = buildDetail(source),
+ description = buildDesc(source),
+ }
+ end),
+ }
+ end
+ ::CONTINUE::
+ end
+end
+
+local function checkFieldFromFieldToIndex(name, parent, word, start, offset)
+ if name:match '^[%a_][%w_]*$' then
+ return nil
+ end
+ local textEdit, additionalTextEdits
+ local uri = guide.getUri(parent)
+ local text = files.getText(uri)
+ local wordStart
+ if word == '' then
+ wordStart = text:match('()%S', start + 1) or (offset + 1)
+ else
+ wordStart = offset - #word + 1
+ end
+ textEdit = {
+ start = wordStart,
+ finish = offset,
+ newText = ('[%q]'):format(name),
+ }
+ local nxt = parent.next
+ if nxt then
+ local dotStart
+ if nxt.type == 'setfield'
+ or nxt.type == 'getfield'
+ or nxt.type == 'tablefield' then
+ dotStart = nxt.dot.start
+ elseif nxt.type == 'setmethod'
+ or nxt.type == 'getmethod' then
+ dotStart = nxt.colon.start
+ end
+ if dotStart then
+ additionalTextEdits = {
+ {
+ start = dotStart,
+ finish = dotStart,
+ newText = '',
+ }
+ }
+ end
+ else
+ if config.config.runtime.version == 'Lua 5.1'
+ or config.config.runtime.version == 'LuaJIT' then
+ textEdit.newText = '_G' .. textEdit.newText
+ else
+ textEdit.newText = '_ENV' .. textEdit.newText
+ end
+ end
+ return textEdit, additionalTextEdits
+end
+
+local function checkFieldThen(name, src, word, start, offset, parent, oop, results)
+ local value = guide.getObjectValue(src) or src
+ local kind = define.CompletionItemKind.Field
+ if value.type == 'function' then
+ if oop then
+ kind = define.CompletionItemKind.Method
+ else
+ kind = define.CompletionItemKind.Function
+ end
+ buildFunction(results, src, oop, {
+ label = name,
+ kind = kind,
+ deprecated = vm.isDeprecated(src) or nil,
+ id = stack(function ()
+ return {
+ detail = buildDetail(src),
+ description = buildDesc(src),
+ }
+ end),
+ })
+ return
+ end
+ if oop then
+ return
+ end
+ local literal = guide.getLiteral(value)
+ if literal ~= nil then
+ kind = define.CompletionItemKind.Enum
+ end
+ local textEdit, additionalTextEdits
+ if parent.next and parent.next.index then
+ local str = parent.next.index
+ textEdit = {
+ start = str.start + #str[2],
+ finish = offset,
+ newText = name,
+ }
+ else
+ textEdit, additionalTextEdits = checkFieldFromFieldToIndex(name, parent, word, start, offset)
+ end
+ results[#results+1] = {
+ label = name,
+ kind = kind,
+ textEdit = textEdit,
+ additionalTextEdits = additionalTextEdits,
+ id = stack(function ()
+ return {
+ detail = buildDetail(src),
+ description = buildDesc(src),
+ }
+ end)
+ }
+end
+
+local function checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results, locals, isGlobal)
+ local fields = {}
+ local count = 0
+ for _, src in ipairs(refs) do
+ local key = vm.getKeyName(src)
+ if not key or key:sub(1, 1) ~= 's' then
+ goto CONTINUE
+ end
+ if isSameSource(ast, src, start) then
+ -- 由于fastGlobal的优化,全局变量只会找出一个值,有可能找出自己
+ -- 所以遇到自己的时候重新找一下有没有其他定义
+ if not isGlobal then
+ goto CONTINUE
+ end
+ if #vm.getGlobals(key) <= 1 then
+ goto CONTINUE
+ elseif not vm.isSet(src) then
+ src = vm.getGlobalSets(key)[1] or src
+ end
+ end
+ local name = key:sub(3)
+ if locals and locals[name] then
+ goto CONTINUE
+ end
+ if not matchKey(word, name, count >= 100) then
+ goto CONTINUE
+ end
+ local last = fields[name]
+ if not last then
+ fields[name] = src
+ count = count + 1
+ goto CONTINUE
+ end
+ if src.type == 'tablefield'
+ or src.type == 'setfield'
+ or src.type == 'tableindex'
+ or src.type == 'setindex'
+ or src.type == 'setmethod'
+ or src.type == 'setglobal' then
+ fields[name] = src
+ goto CONTINUE
+ end
+ ::CONTINUE::
+ end
+ for name, src in util.sortPairs(fields) do
+ checkFieldThen(name, src, word, start, offset, parent, oop, results)
+ end
+end
+
+local function checkField(ast, word, start, offset, parent, oop, results)
+ local refs = vm.getFields(parent, 'deep')
+ checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results)
+end
+
+local function checkGlobal(ast, word, start, offset, parent, oop, results)
+ local locals = guide.getVisibleLocals(ast.ast, offset)
+ local refs = vm.getGlobalSets '*'
+ checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results, locals, 'global')
+end
+
+local function checkTableField(ast, word, start, results)
+ local source = guide.eachSourceContain(ast.ast, start, function (source)
+ if source.start == start
+ and source.parent
+ and source.parent.type == 'table' then
+ return source
+ end
+ end)
+ if not source then
+ return
+ end
+ local used = {}
+ guide.eachSourceType(ast.ast, 'tablefield', function (src)
+ if not src.field then
+ return
+ end
+ local key = src.field[1]
+ if not used[key]
+ and matchKey(word, key)
+ and src ~= source then
+ used[key] = true
+ results[#results+1] = {
+ label = key,
+ kind = define.CompletionItemKind.Property,
+ }
+ end
+ end)
+end
+
+local function checkCommon(word, text, offset, results)
+ local used = {}
+ for _, result in ipairs(results) do
+ used[result.label] = true
+ end
+ for _, data in ipairs(keyWordMap) do
+ used[data[1]] = true
+ end
+ for str, pos in text:gmatch '([%a_][%w_]*)()' do
+ if not used[str] and pos - 1 ~= offset then
+ used[str] = true
+ if matchKey(word, str) then
+ results[#results+1] = {
+ label = str,
+ kind = define.CompletionItemKind.Text,
+ }
+ end
+ end
+ end
+end
+
+local function isInString(ast, offset)
+ return guide.eachSourceContain(ast.ast, offset, function (source)
+ if source.type == 'string' then
+ return true
+ end
+ end)
+end
+
+local function checkKeyWord(ast, text, start, word, hasSpace, afterLocal, results)
+ local snipType = config.config.completion.keywordSnippet
+ for _, data in ipairs(keyWordMap) do
+ local key = data[1]
+ local eq
+ if hasSpace then
+ eq = word == key
+ else
+ eq = matchKey(word, key)
+ end
+ if afterLocal and key ~= 'function' then
+ eq = false
+ end
+ if eq then
+ local replaced
+ local extra
+ if snipType == 'Both' or snipType == 'Replace' then
+ local func = data[2]
+ if func then
+ replaced = func(hasSpace, results)
+ extra = true
+ end
+ end
+ if snipType == 'Both' then
+ replaced = false
+ end
+ if not replaced then
+ if not hasSpace then
+ local item = {
+ label = key,
+ kind = define.CompletionItemKind.Keyword,
+ }
+ if extra then
+ table.insert(results, #results, item)
+ else
+ results[#results+1] = item
+ end
+ end
+ end
+ local checkStop = data[3]
+ if checkStop then
+ local stop = checkStop(ast, start)
+ if stop then
+ return true
+ end
+ end
+ end
+ end
+end
+
+local function checkProvideLocal(ast, word, start, results)
+ local block
+ guide.eachSourceContain(ast.ast, start, function (source)
+ if source.type == 'function'
+ or source.type == 'main' then
+ block = source
+ end
+ end)
+ if not block then
+ return
+ end
+ local used = {}
+ guide.eachSourceType(block, 'getglobal', function (source)
+ if source.start > start
+ and not used[source[1]]
+ and matchKey(word, source[1]) then
+ used[source[1]] = true
+ results[#results+1] = {
+ label = source[1],
+ kind = define.CompletionItemKind.Variable,
+ }
+ end
+ end)
+ guide.eachSourceType(block, 'getlocal', function (source)
+ if source.start > start
+ and not used[source[1]]
+ and matchKey(word, source[1]) then
+ used[source[1]] = true
+ results[#results+1] = {
+ label = source[1],
+ kind = define.CompletionItemKind.Variable,
+ }
+ end
+ end)
+end
+
+local function checkFunctionArgByDocParam(ast, word, start, results)
+ local func = guide.eachSourceContain(ast.ast, start, function (source)
+ if source.type == 'function' then
+ return source
+ end
+ end)
+ if not func then
+ return
+ end
+ local docs = func.bindDocs
+ if not docs then
+ return
+ end
+ local params = {}
+ for _, doc in ipairs(docs) do
+ if doc.type == 'doc.param' then
+ params[#params+1] = doc
+ end
+ end
+ local firstArg = func.args and func.args[1]
+ if not firstArg
+ or firstArg.start <= start and firstArg.finish >= start then
+ local firstParam = params[1]
+ if firstParam and matchKey(word, firstParam.param[1]) then
+ local label = {}
+ for _, param in ipairs(params) do
+ label[#label+1] = param.param[1]
+ end
+ results[#results+1] = {
+ label = table.concat(label, ', '),
+ kind = define.CompletionItemKind.Snippet,
+ }
+ end
+ end
+ for _, doc in ipairs(params) do
+ if matchKey(word, doc.param[1]) then
+ results[#results+1] = {
+ label = doc.param[1],
+ kind = define.CompletionItemKind.Interface,
+ }
+ end
+ end
+end
+
+local function isAfterLocal(text, start)
+ local pos = skipSpace(text, start-1)
+ local word = findWord(text, pos)
+ return word == 'local'
+end
+
+local function checkUri(ast, text, offset, results)
+ local collect = {}
+ local myUri = guide.getUri(ast.ast)
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ if source.type ~= 'string' then
+ return
+ end
+ local callargs = source.parent
+ if not callargs or callargs.type ~= 'callargs' then
+ return
+ end
+ if callargs[1] ~= source then
+ return
+ end
+ local call = callargs.parent
+ local func = call.node
+ local literal = guide.getLiteral(source)
+ local libName = vm.getLibraryName(func)
+ if not libName then
+ return
+ end
+ if libName == 'require' then
+ for uri in files.eachFile() do
+ uri = files.getOriginUri(uri)
+ if files.eq(myUri, uri) then
+ goto CONTINUE
+ end
+ if vm.isMetaFile(uri) then
+ goto CONTINUE
+ end
+ local path = workspace.getRelativePath(uri)
+ local infos = rpath.getVisiblePath(path, config.config.runtime.path)
+ for _, info in ipairs(infos) do
+ if matchKey(literal, info.expect) then
+ if not collect[info.expect] then
+ collect[info.expect] = {
+ textEdit = {
+ start = source.start + #source[2],
+ finish = source.finish - #source[2],
+ }
+ }
+ end
+ collect[info.expect][#collect[info.expect]+1] = ([=[* [%s](%s) %s]=]):format(
+ path,
+ uri,
+ lang.script('HOVER_USE_LUA_PATH', info.searcher)
+ )
+ end
+ end
+ ::CONTINUE::
+ end
+ elseif libName == 'dofile'
+ or libName == 'loadfile' then
+ for uri in files.eachFile() do
+ uri = files.getOriginUri(uri)
+ if files.eq(myUri, uri) then
+ goto CONTINUE
+ end
+ if vm.isMetaFile(uri) then
+ goto CONTINUE
+ end
+ local path = workspace.getRelativePath(uri)
+ if matchKey(literal, path) then
+ if not collect[path] then
+ collect[path] = {
+ textEdit = {
+ start = source.start + #source[2],
+ finish = source.finish - #source[2],
+ }
+ }
+ end
+ collect[path][#collect[path]+1] = ([=[[%s](%s)]=]):format(
+ path,
+ uri
+ )
+ end
+ ::CONTINUE::
+ end
+ end
+ end)
+ for label, infos in util.sortPairs(collect) do
+ local mark = {}
+ local des = {}
+ for _, info in ipairs(infos) do
+ if not mark[info] then
+ mark[info] = true
+ des[#des+1] = info
+ end
+ end
+ results[#results+1] = {
+ label = label,
+ kind = define.CompletionItemKind.Reference,
+ description = table.concat(des, '\n'),
+ textEdit = infos.textEdit,
+ }
+ end
+end
+
+local function checkLenPlusOne(ast, text, offset, results)
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ if source.type == 'getindex'
+ or source.type == 'setindex' then
+ local _, pos = text:find('%s*%[%s*%#', source.node.finish)
+ if not pos then
+ return
+ end
+ local nodeText = text:sub(source.node.start, source.node.finish)
+ local writingText = trim(text:sub(pos + 1, offset - 1)) or ''
+ if not matchKey(writingText, nodeText) then
+ return
+ end
+ if source.parent == guide.getParentBlock(source) then
+ -- state
+ local label = text:match('%#[ \t]*', pos) .. nodeText .. '+1'
+ local eq = text:find('^%s*%]?%s*%=', source.finish)
+ local newText = label .. ']'
+ if not eq then
+ newText = newText .. ' = '
+ end
+ results[#results+1] = {
+ label = label,
+ kind = define.CompletionItemKind.Snippet,
+ textEdit = {
+ start = pos,
+ finish = source.finish,
+ newText = newText,
+ },
+ }
+ else
+ -- exp
+ local label = text:match('%#[ \t]*', pos) .. nodeText
+ local newText = label .. ']'
+ results[#results+1] = {
+ label = label,
+ kind = define.CompletionItemKind.Snippet,
+ textEdit = {
+ start = pos,
+ finish = source.finish,
+ newText = newText,
+ },
+ }
+ end
+ end
+ end)
+end
+
+local function isFuncArg(ast, offset)
+ return guide.eachSourceContain(ast.ast, offset, function (source)
+ if source.type == 'funcargs' then
+ return true
+ end
+ end)
+end
+
+local function trySpecial(ast, text, offset, results)
+ if isInString(ast, offset) then
+ checkUri(ast, text, offset, results)
+ return
+ end
+ -- x[#x+1]
+ checkLenPlusOne(ast, text, offset, results)
+end
+
+local function tryIndex(ast, text, offset, results)
+ local parent, oop = findParentInStringIndex(ast, text, offset)
+ if not parent then
+ return
+ end
+ local word = parent.next.index[1]
+ checkField(ast, word, offset, offset, parent, oop, results)
+end
+
+local function tryWord(ast, text, offset, results)
+ local finish = skipSpace(text, offset)
+ local word, start = findWord(text, finish)
+ if not word then
+ return nil
+ end
+ local hasSpace = finish ~= offset
+ if isInString(ast, offset) then
+ else
+ local parent, oop = findParent(ast, text, start - 1)
+ if parent then
+ if not hasSpace then
+ checkField(ast, word, start, offset, parent, oop, results)
+ end
+ elseif isFuncArg(ast, offset) then
+ checkProvideLocal(ast, word, start, results)
+ checkFunctionArgByDocParam(ast, word, start, results)
+ else
+ local afterLocal = isAfterLocal(text, start)
+ local stop = checkKeyWord(ast, text, start, word, hasSpace, afterLocal, results)
+ if stop then
+ return
+ end
+ if not hasSpace then
+ if afterLocal then
+ checkProvideLocal(ast, word, start, results)
+ else
+ checkLocal(ast, word, start, results)
+ checkTableField(ast, word, start, results)
+ local env = guide.getENV(ast.ast, start)
+ checkGlobal(ast, word, start, offset, env, false, results)
+ end
+ end
+ end
+ if not hasSpace then
+ checkCommon(word, text, offset, results)
+ end
+ end
+end
+
+local function trySymbol(ast, text, offset, results)
+ local symbol, start = findSymbol(text, offset)
+ if not symbol then
+ return nil
+ end
+ if isInString(ast, offset) then
+ return nil
+ end
+ if symbol == '.'
+ or symbol == ':' then
+ local parent, oop = findParent(ast, text, start)
+ if parent then
+ checkField(ast, '', start, offset, parent, oop, results)
+ end
+ end
+ if symbol == '(' then
+ checkFunctionArgByDocParam(ast, '', start, results)
+ end
+end
+
+local function getCallEnums(source, index)
+ if source.type == 'function' and source.bindDocs then
+ if not source.args then
+ return
+ end
+ local arg
+ if index <= #source.args then
+ arg = source.args[index]
+ else
+ local lastArg = source.args[#source.args]
+ if lastArg.type == '...' then
+ arg = lastArg
+ else
+ return
+ end
+ end
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.param'
+ and doc.param[1] == arg[1] then
+ local enums = {}
+ for _, enum in ipairs(vm.getDocEnums(doc.extends)) do
+ enums[#enums+1] = {
+ label = enum[1],
+ description = enum.comment,
+ kind = define.CompletionItemKind.EnumMember,
+ }
+ end
+ return enums
+ elseif doc.type == 'doc.vararg'
+ and arg.type == '...' then
+ local enums = {}
+ for _, enum in ipairs(vm.getDocEnums(doc.vararg)) do
+ enums[#enums+1] = {
+ label = enum[1],
+ description = enum.comment,
+ kind = define.CompletionItemKind.EnumMember,
+ }
+ end
+ return enums
+ end
+ end
+ end
+end
+
+local function tryLabelInString(label, arg)
+ if not arg or arg.type ~= 'string' then
+ return label
+ end
+ local str = parser:grammar(label, 'String')
+ if not str then
+ return label
+ end
+ if not matchKey(arg[1], str[1]) then
+ return nil
+ end
+ return util.viewString(str[1], arg[2])
+end
+
+local function mergeEnums(a, b, text, arg)
+ local mark = {}
+ for _, enum in ipairs(a) do
+ mark[enum.label] = true
+ end
+ for _, enum in ipairs(b) do
+ local label = tryLabelInString(enum.label, arg)
+ if label and not mark[label] then
+ mark[label] = true
+ local result = {
+ label = label,
+ kind = define.CompletionItemKind.EnumMember,
+ description = enum.description,
+ textEdit = arg and {
+ start = arg.start,
+ finish = arg.finish,
+ newText = label,
+ },
+ }
+ a[#a+1] = result
+ end
+ end
+end
+
+local function findCall(ast, text, offset)
+ local call
+ guide.eachSourceContain(ast.ast, offset, function (src)
+ if src.type == 'call' then
+ if not call or call.start < src.start then
+ call = src
+ end
+ end
+ end)
+ return call
+end
+
+local function getCallArgInfo(call, text, offset)
+ if not call.args then
+ return 1, nil
+ end
+ for index, arg in ipairs(call.args) do
+ if arg.start <= offset and arg.finish >= offset then
+ return index, arg
+ end
+ end
+ return #call.args + 1, nil
+end
+
+local function tryCallArg(ast, text, offset, results)
+ local call = findCall(ast, text, offset)
+ if not call then
+ return
+ end
+ local myResults = {}
+ local argIndex, arg = getCallArgInfo(call, text, offset)
+ if arg and arg.type == 'function' then
+ return
+ end
+ local defs = vm.getDefs(call.node, 'deep')
+ for _, def in ipairs(defs) do
+ def = guide.getObjectValue(def) or def
+ local enums = getCallEnums(def, argIndex)
+ if enums then
+ mergeEnums(myResults, enums, text, arg)
+ end
+ end
+ for _, enum in ipairs(myResults) do
+ results[#results+1] = enum
+ end
+end
+
+local function getComment(ast, offset)
+ for _, comm in ipairs(ast.comms) do
+ if offset >= comm.start and offset <= comm.finish then
+ return comm
+ end
+ end
+ return nil
+end
+
+local function tryLuaDocCate(line, results)
+ local word = line:sub(3)
+ for _, docType in ipairs {
+ 'class',
+ 'type',
+ 'alias',
+ 'param',
+ 'return',
+ 'field',
+ 'generic',
+ 'vararg',
+ 'overload',
+ 'deprecated',
+ 'meta',
+ 'version',
+ } do
+ if matchKey(word, docType) then
+ results[#results+1] = {
+ label = docType,
+ kind = define.CompletionItemKind.Event,
+ }
+ end
+ end
+end
+
+local function getLuaDocByContain(ast, offset)
+ local result
+ local range = math.huge
+ guide.eachSourceContain(ast.ast.docs, offset, function (src)
+ if not src.start then
+ return
+ end
+ if range >= offset - src.start
+ and offset <= src.finish then
+ range = offset - src.start
+ result = src
+ end
+ end)
+ return result
+end
+
+local function getLuaDocByErr(ast, text, start, offset)
+ local targetError
+ for _, err in ipairs(ast.errs) do
+ if err.finish <= offset
+ and err.start >= start then
+ if not text:sub(err.finish + 1, offset):find '%S' then
+ targetError = err
+ break
+ end
+ end
+ end
+ if not targetError then
+ return nil
+ end
+ local targetDoc
+ for i = #ast.ast.docs, 1, -1 do
+ local doc = ast.ast.docs[i]
+ if doc.finish <= targetError.start then
+ targetDoc = doc
+ break
+ end
+ end
+ return targetError, targetDoc
+end
+
+local function tryLuaDocBySource(ast, offset, source, results)
+ if source.type == 'doc.extends.name' then
+ if source.parent.type == 'doc.class' then
+ for _, doc in ipairs(vm.getDocTypes '*') do
+ if doc.type == 'doc.class.name'
+ and doc.parent ~= source.parent
+ and matchKey(source[1], doc[1]) then
+ results[#results+1] = {
+ label = doc[1],
+ kind = define.CompletionItemKind.Class,
+ }
+ end
+ end
+ end
+ elseif source.type == 'doc.type.name' then
+ for _, doc in ipairs(vm.getDocTypes '*') do
+ if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name')
+ and doc.parent ~= source.parent
+ and matchKey(source[1], doc[1]) then
+ results[#results+1] = {
+ label = doc[1],
+ kind = define.CompletionItemKind.Class,
+ }
+ end
+ end
+ elseif source.type == 'doc.param.name' then
+ local funcs = {}
+ guide.eachSourceBetween(ast.ast, offset, math.huge, function (src)
+ if src.type == 'function' and src.start > offset then
+ funcs[#funcs+1] = src
+ end
+ end)
+ table.sort(funcs, function (a, b)
+ return a.start < b.start
+ end)
+ local func = funcs[1]
+ if not func or not func.args then
+ return
+ end
+ for _, arg in ipairs(func.args) do
+ if arg[1] and matchKey(source[1], arg[1]) then
+ results[#results+1] = {
+ label = arg[1],
+ kind = define.CompletionItemKind.Interface,
+ }
+ end
+ end
+ end
+end
+
+local function tryLuaDocByErr(ast, offset, err, docState, results)
+ if err.type == 'LUADOC_MISS_CLASS_EXTENDS_NAME' then
+ for _, doc in ipairs(vm.getDocTypes '*') do
+ if doc.type == 'doc.class.name'
+ and doc.parent ~= docState then
+ results[#results+1] = {
+ label = doc[1],
+ kind = define.CompletionItemKind.Class,
+ }
+ end
+ end
+ elseif err.type == 'LUADOC_MISS_TYPE_NAME' then
+ for _, doc in ipairs(vm.getDocTypes '*') do
+ if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name') then
+ results[#results+1] = {
+ label = doc[1],
+ kind = define.CompletionItemKind.Class,
+ }
+ end
+ end
+ elseif err.type == 'LUADOC_MISS_PARAM_NAME' then
+ local funcs = {}
+ guide.eachSourceBetween(ast.ast, offset, math.huge, function (src)
+ if src.type == 'function' and src.start > offset then
+ funcs[#funcs+1] = src
+ end
+ end)
+ table.sort(funcs, function (a, b)
+ return a.start < b.start
+ end)
+ local func = funcs[1]
+ if not func or not func.args then
+ return
+ end
+ local label = {}
+ local insertText = {}
+ for i, arg in ipairs(func.args) do
+ if arg[1] then
+ label[#label+1] = arg[1]
+ if i == 1 then
+ insertText[i] = ('%s ${%d:any}'):format(arg[1], i)
+ else
+ insertText[i] = ('---@param %s ${%d:any}'):format(arg[1], i)
+ end
+ end
+ end
+ results[#results+1] = {
+ label = table.concat(label, ', '),
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = table.concat(insertText, '\n'),
+ }
+ for i, arg in ipairs(func.args) do
+ if arg[1] then
+ results[#results+1] = {
+ label = arg[1],
+ kind = define.CompletionItemKind.Interface,
+ }
+ end
+ end
+ end
+end
+
+local function tryLuaDocFeatures(line, ast, comm, offset, results)
+end
+
+local function tryLuaDoc(ast, text, offset, results)
+ local comm = getComment(ast, offset)
+ local line = text:sub(comm.start, offset)
+ if not line then
+ return
+ end
+ if line:sub(1, 2) ~= '-@' then
+ return
+ end
+ -- 尝试 ---@$
+ local cate = line:match('%a*', 3)
+ if #cate + 2 >= #line then
+ tryLuaDocCate(line, results)
+ return
+ end
+ -- 尝试一些其他特征
+ if tryLuaDocFeatures(line, ast, comm, offset, results) then
+ return
+ end
+ -- 根据输入中的source来补全
+ local source = getLuaDocByContain(ast, offset)
+ if source then
+ tryLuaDocBySource(ast, offset, source, results)
+ return
+ end
+ -- 根据附近的错误消息来补全
+ local err, doc = getLuaDocByErr(ast, text, comm.start, offset)
+ if err then
+ tryLuaDocByErr(ast, offset, err, doc, results)
+ return
+ end
+end
+
+local function completion(uri, offset)
+ local ast = files.getAst(uri)
+ local text = files.getText(uri)
+ local results = {}
+ clearStack()
+ if ast then
+ if getComment(ast, offset) then
+ tryLuaDoc(ast, text, offset, results)
+ else
+ trySpecial(ast, text, offset, results)
+ tryWord(ast, text, offset, results)
+ tryIndex(ast, text, offset, results)
+ trySymbol(ast, text, offset, results)
+ tryCallArg(ast, text, offset, results)
+ end
+ else
+ local word = findWord(text, offset)
+ if word then
+ checkCommon(word, text, offset, results)
+ end
+ end
+
+ if #results == 0 then
+ return nil
+ end
+ return results
+end
+
+local function resolve(id)
+ return resolveStack(id)
+end
+
+return {
+ completion = completion,
+ resolve = resolve,
+}
diff --git a/script/core/definition.lua b/script/core/definition.lua
new file mode 100644
index 00000000..c143939d
--- /dev/null
+++ b/script/core/definition.lua
@@ -0,0 +1,156 @@
+local guide = require 'parser.guide'
+local workspace = require 'workspace'
+local files = require 'files'
+local vm = require 'vm'
+local findSource = require 'core.find-source'
+
+local function sortResults(results)
+ -- 先按照顺序排序
+ table.sort(results, function (a, b)
+ local u1 = guide.getUri(a.target)
+ local u2 = guide.getUri(b.target)
+ if u1 == u2 then
+ return a.target.start < b.target.start
+ else
+ return u1 < u2
+ end
+ end)
+ -- 如果2个结果处于嵌套状态,则取范围小的那个
+ local lf, lu
+ for i = #results, 1, -1 do
+ local res = results[i].target
+ local f = res.finish
+ local uri = guide.getUri(res)
+ if lf and f > lf and uri == lu then
+ table.remove(results, i)
+ else
+ lu = uri
+ lf = f
+ end
+ end
+end
+
+local accept = {
+ ['local'] = true,
+ ['setlocal'] = true,
+ ['getlocal'] = true,
+ ['label'] = true,
+ ['goto'] = true,
+ ['field'] = true,
+ ['method'] = true,
+ ['setglobal'] = true,
+ ['getglobal'] = true,
+ ['string'] = true,
+ ['boolean'] = true,
+ ['number'] = true,
+
+ ['doc.type.name'] = true,
+ ['doc.class.name'] = true,
+ ['doc.extends.name'] = true,
+ ['doc.alias.name'] = true,
+}
+
+local function checkRequire(source, offset)
+ if source.type ~= 'string' then
+ return nil
+ end
+ local callargs = source.parent
+ if callargs.type ~= 'callargs' then
+ return
+ end
+ if callargs[1] ~= source then
+ return
+ end
+ local call = callargs.parent
+ local func = call.node
+ local literal = guide.getLiteral(source)
+ local libName = vm.getLibraryName(func)
+ if not libName then
+ return nil
+ end
+ if libName == 'require' then
+ return workspace.findUrisByRequirePath(literal)
+ elseif libName == 'dofile'
+ or libName == 'loadfile' then
+ return workspace.findUrisByFilePath(literal)
+ end
+ return nil
+end
+
+local function convertIndex(source)
+ if not source then
+ return
+ end
+ if source.type == 'string'
+ or source.type == 'boolean'
+ or source.type == 'number' then
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ if parent.type == 'setindex'
+ or parent.type == 'getindex'
+ or parent.type == 'tableindex' then
+ return parent
+ end
+ end
+ return source
+end
+
+return function (uri, offset)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+
+ local source = convertIndex(findSource(ast, offset, accept))
+ if not source then
+ return nil
+ end
+
+ local results = {}
+ local uris = checkRequire(source)
+ if uris then
+ for i, uri in ipairs(uris) do
+ results[#results+1] = {
+ uri = files.getOriginUri(uri),
+ source = source,
+ target = {
+ start = 0,
+ finish = 0,
+ uri = uri,
+ }
+ }
+ end
+ end
+
+ for _, src in ipairs(vm.getDefs(source, 'deep')) do
+ local root = guide.getRoot(src)
+ if not root then
+ goto CONTINUE
+ end
+ src = src.field or src.method or src.index or src
+ if src.type == 'table' and src.parent.type ~= 'return' then
+ goto CONTINUE
+ end
+ if src.type == 'doc.class.name'
+ and source.type ~= 'doc.type.name'
+ and source.type ~= 'doc.extends.name' then
+ goto CONTINUE
+ end
+ results[#results+1] = {
+ target = src,
+ uri = files.getOriginUri(root.uri),
+ source = source,
+ }
+ ::CONTINUE::
+ end
+
+ if #results == 0 then
+ return nil
+ end
+
+ sortResults(results)
+
+ return results
+end
diff --git a/script/core/diagnostics/ambiguity-1.lua b/script/core/diagnostics/ambiguity-1.lua
new file mode 100644
index 00000000..37815fb5
--- /dev/null
+++ b/script/core/diagnostics/ambiguity-1.lua
@@ -0,0 +1,69 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+
+local opMap = {
+ ['+'] = true,
+ ['-'] = true,
+ ['*'] = true,
+ ['/'] = true,
+ ['//'] = true,
+ ['^'] = true,
+ ['<<'] = true,
+ ['>>'] = true,
+ ['&'] = true,
+ ['|'] = true,
+ ['~'] = true,
+ ['..'] = true,
+}
+
+local literalMap = {
+ ['number'] = true,
+ ['boolean'] = true,
+ ['string'] = true,
+ ['table'] = true,
+}
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ local text = files.getText(uri)
+ guide.eachSourceType(ast.ast, 'binary', function (source)
+ if source.op.type ~= 'or' then
+ return
+ end
+ local first = source[1]
+ local second = source[2]
+ -- a + (b or 0) --> (a + b) or 0
+ do
+ if opMap[first.op and first.op.type]
+ and first.type ~= 'unary'
+ and not second.op
+ and literalMap[second.type]
+ and not literalMap[first[2].type]
+ then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script('DIAG_AMBIGUITY_1', text:sub(first.start, first.finish))
+ }
+ end
+ end
+ -- (a or 0) + c --> a or (0 + c)
+ do
+ if opMap[second.op and second.op.type]
+ and second.type ~= 'unary'
+ and not first.op
+ and literalMap[second[1].type]
+ then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script('DIAG_AMBIGUITY_1', text:sub(second.start, second.finish))
+ }
+ end
+ end
+ end)
+end
diff --git a/script/core/diagnostics/circle-doc-class.lua b/script/core/diagnostics/circle-doc-class.lua
new file mode 100644
index 00000000..55179447
--- /dev/null
+++ b/script/core/diagnostics/circle-doc-class.lua
@@ -0,0 +1,54 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+local vm = require 'vm'
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ for _, doc in ipairs(state.ast.docs) do
+ if doc.type == 'doc.class' then
+ if not doc.extends then
+ goto CONTINUE
+ end
+ local myName = guide.getName(doc)
+ local list = { doc }
+ local mark = {}
+ for i = 1, 999 do
+ local current = list[i]
+ if not current then
+ goto CONTINUE
+ end
+ if current.extends then
+ local newName = current.extends[1]
+ if newName == myName then
+ callback {
+ start = doc.start,
+ finish = doc.finish,
+ message = lang.script('DIAG_CIRCLE_DOC_CLASS', myName)
+ }
+ goto CONTINUE
+ end
+ if not mark[newName] then
+ mark[newName] = true
+ local docs = vm.getDocTypes(newName)
+ for _, otherDoc in ipairs(docs) do
+ if otherDoc.type == 'doc.class.name' then
+ list[#list+1] = otherDoc.parent
+ end
+ end
+ end
+ end
+ end
+ ::CONTINUE::
+ end
+ end
+end
diff --git a/script/core/diagnostics/code-after-break.lua b/script/core/diagnostics/code-after-break.lua
new file mode 100644
index 00000000..a2bac8a4
--- /dev/null
+++ b/script/core/diagnostics/code-after-break.lua
@@ -0,0 +1,34 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ local mark = {}
+ guide.eachSourceType(state.ast, 'break', function (source)
+ local list = source.parent
+ if mark[list] then
+ return
+ end
+ mark[list] = true
+ for i = #list, 1, -1 do
+ local src = list[i]
+ if src == source then
+ if i == #list then
+ return
+ end
+ callback {
+ start = list[i+1].start,
+ finish = list[#list].range or list[#list].finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_CODE_AFTER_BREAK,
+ }
+ end
+ end
+ end)
+end
diff --git a/script/core/diagnostics/doc-field-no-class.lua b/script/core/diagnostics/doc-field-no-class.lua
new file mode 100644
index 00000000..f27bbb32
--- /dev/null
+++ b/script/core/diagnostics/doc-field-no-class.lua
@@ -0,0 +1,41 @@
+local files = require 'files'
+local lang = require 'language'
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ for _, doc in ipairs(state.ast.docs) do
+ if doc.type ~= 'doc.field' then
+ goto CONTINUE
+ end
+ local bindGroup = doc.bindGroup
+ if not bindGroup then
+ goto CONTINUE
+ end
+ local ok
+ for _, other in ipairs(bindGroup) do
+ if other.type == 'doc.class' then
+ ok = true
+ break
+ end
+ if other == doc then
+ break
+ end
+ end
+ if not ok then
+ callback {
+ start = doc.start,
+ finish = doc.finish,
+ message = lang.script('DIAG_DOC_FIELD_NO_CLASS'),
+ }
+ end
+ ::CONTINUE::
+ end
+end
diff --git a/script/core/diagnostics/duplicate-doc-class.lua b/script/core/diagnostics/duplicate-doc-class.lua
new file mode 100644
index 00000000..259c048b
--- /dev/null
+++ b/script/core/diagnostics/duplicate-doc-class.lua
@@ -0,0 +1,46 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+local vm = require 'vm'
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ local cache = {}
+ for _, doc in ipairs(state.ast.docs) do
+ if doc.type == 'doc.class'
+ or doc.type == 'doc.alias' then
+ local name = guide.getName(doc)
+ if not cache[name] then
+ local docs = vm.getDocTypes(name)
+ cache[name] = {}
+ for _, otherDoc in ipairs(docs) do
+ if otherDoc.type == 'doc.class.name'
+ or otherDoc.type == 'doc.alias.name' then
+ cache[name][#cache[name]+1] = {
+ start = otherDoc.start,
+ finish = otherDoc.finish,
+ uri = guide.getUri(otherDoc),
+ }
+ end
+ end
+ end
+ if #cache[name] > 1 then
+ callback {
+ start = doc.start,
+ finish = doc.finish,
+ related = cache,
+ message = lang.script('DIAG_DUPLICATE_DOC_CLASS', name)
+ }
+ end
+ end
+ end
+end
diff --git a/script/core/diagnostics/duplicate-doc-field.lua b/script/core/diagnostics/duplicate-doc-field.lua
new file mode 100644
index 00000000..b621fd9e
--- /dev/null
+++ b/script/core/diagnostics/duplicate-doc-field.lua
@@ -0,0 +1,34 @@
+local files = require 'files'
+local lang = require 'language'
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ local mark
+ for _, group in ipairs(state.ast.docs.groups) do
+ for _, doc in ipairs(group) do
+ if doc.type == 'doc.class' then
+ mark = {}
+ elseif doc.type == 'doc.field' then
+ if mark then
+ local name = doc.field[1]
+ if mark[name] then
+ callback {
+ start = doc.field.start,
+ finish = doc.field.finish,
+ message = lang.script('DIAG_DUPLICATE_DOC_FIELD', name),
+ }
+ end
+ mark[name] = true
+ end
+ end
+ end
+ end
+end
diff --git a/script/core/diagnostics/duplicate-doc-param.lua b/script/core/diagnostics/duplicate-doc-param.lua
new file mode 100644
index 00000000..676a6fb4
--- /dev/null
+++ b/script/core/diagnostics/duplicate-doc-param.lua
@@ -0,0 +1,37 @@
+local files = require 'files'
+local lang = require 'language'
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ for _, doc in ipairs(state.ast.docs) do
+ if doc.type ~= 'doc.param' then
+ goto CONTINUE
+ end
+ local name = doc.param[1]
+ local bindGroup = doc.bindGroup
+ if not bindGroup then
+ goto CONTINUE
+ end
+ for _, other in ipairs(bindGroup) do
+ if other ~= doc
+ and other.type == 'doc.param'
+ and other.param[1] == name then
+ callback {
+ start = doc.param.start,
+ finish = doc.param.finish,
+ message = lang.script('DIAG_DUPLICATE_DOC_PARAM', name)
+ }
+ goto CONTINUE
+ end
+ end
+ ::CONTINUE::
+ end
+end
diff --git a/script/core/diagnostics/duplicate-index.lua b/script/core/diagnostics/duplicate-index.lua
new file mode 100644
index 00000000..dabe1b3c
--- /dev/null
+++ b/script/core/diagnostics/duplicate-index.lua
@@ -0,0 +1,63 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+local vm = require 'vm'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ guide.eachSourceType(ast.ast, 'table', function (source)
+ local mark = {}
+ for _, obj in ipairs(source) do
+ if obj.type == 'tablefield'
+ or obj.type == 'tableindex' then
+ local name = vm.getKeyName(obj)
+ if name then
+ if not mark[name] then
+ mark[name] = {}
+ end
+ mark[name][#mark[name]+1] = obj.field or obj.index
+ end
+ end
+ end
+
+ for name, defs in pairs(mark) do
+ local sname = name:match '^.|(.+)$'
+ if #defs > 1 and sname then
+ local related = {}
+ for i = 1, #defs do
+ local def = defs[i]
+ related[i] = {
+ start = def.start,
+ finish = def.finish,
+ uri = uri,
+ }
+ end
+ for i = 1, #defs - 1 do
+ local def = defs[i]
+ callback {
+ start = def.start,
+ finish = def.finish,
+ related = related,
+ message = lang.script('DIAG_DUPLICATE_INDEX', sname),
+ level = define.DiagnosticSeverity.Hint,
+ tags = { define.DiagnosticTag.Unnecessary },
+ }
+ end
+ for i = #defs, #defs do
+ local def = defs[i]
+ callback {
+ start = def.start,
+ finish = def.finish,
+ related = related,
+ message = lang.script('DIAG_DUPLICATE_INDEX', sname),
+ }
+ end
+ end
+ end
+ end)
+end
diff --git a/script/core/diagnostics/empty-block.lua b/script/core/diagnostics/empty-block.lua
new file mode 100644
index 00000000..2024f4e3
--- /dev/null
+++ b/script/core/diagnostics/empty-block.lua
@@ -0,0 +1,49 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+
+-- 检查空代码块
+-- 但是排除忙等待(repeat/while)
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ guide.eachSourceType(ast.ast, 'if', function (source)
+ for _, block in ipairs(source) do
+ if #block > 0 then
+ return
+ end
+ end
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_EMPTY_BLOCK,
+ }
+ end)
+ guide.eachSourceType(ast.ast, 'loop', function (source)
+ if #source > 0 then
+ return
+ end
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_EMPTY_BLOCK,
+ }
+ end)
+ guide.eachSourceType(ast.ast, 'in', function (source)
+ if #source > 0 then
+ return
+ end
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_EMPTY_BLOCK,
+ }
+ end)
+end
diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua
new file mode 100644
index 00000000..9a0d4f35
--- /dev/null
+++ b/script/core/diagnostics/global-in-nil-env.lua
@@ -0,0 +1,66 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+
+-- TODO: 检查路径是否可达
+local function mayRun(path)
+ return true
+end
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ local root = guide.getRoot(ast.ast)
+ local env = guide.getENV(root)
+
+ local nilDefs = {}
+ if not env.ref then
+ return
+ end
+ for _, ref in ipairs(env.ref) do
+ if ref.type == 'setlocal' then
+ if ref.value and ref.value.type == 'nil' then
+ nilDefs[#nilDefs+1] = ref
+ end
+ end
+ end
+
+ if #nilDefs == 0 then
+ return
+ end
+
+ local function check(source)
+ local node = source.node
+ if node.tag == '_ENV' then
+ local ok
+ for _, nilDef in ipairs(nilDefs) do
+ local mode, pathA = guide.getPath(nilDef, source)
+ if mode == 'before'
+ and mayRun(pathA) then
+ ok = nilDef
+ break
+ end
+ end
+ if ok then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ uri = uri,
+ message = lang.script.DIAG_GLOBAL_IN_NIL_ENV,
+ related = {
+ {
+ start = ok.start,
+ finish = ok.finish,
+ uri = uri,
+ }
+ }
+ }
+ end
+ end
+ end
+
+ guide.eachSourceType(ast.ast, 'getglobal', check)
+ guide.eachSourceType(ast.ast, 'setglobal', check)
+end
diff --git a/script/core/diagnostics/init.lua b/script/core/diagnostics/init.lua
new file mode 100644
index 00000000..a6b61e12
--- /dev/null
+++ b/script/core/diagnostics/init.lua
@@ -0,0 +1,56 @@
+local files = require 'files'
+local define = require 'proto.define'
+local config = require 'config'
+local await = require 'await'
+
+-- 把耗时最长的诊断放到最后面
+local diagLevel = {
+ ['redundant-parameter'] = 100,
+}
+
+local diagList = {}
+for k in pairs(define.DiagnosticDefaultSeverity) do
+ diagList[#diagList+1] = k
+end
+table.sort(diagList, function (a, b)
+ return (diagLevel[a] or 0) < (diagLevel[b] or 0)
+end)
+
+local function check(uri, name, results)
+ if config.config.diagnostics.disable[name] then
+ return
+ end
+ local level = config.config.diagnostics.severity[name]
+ or define.DiagnosticDefaultSeverity[name]
+ if level == 'Hint' and not files.isOpen(uri) then
+ return
+ end
+ local severity = define.DiagnosticSeverity[level]
+ local clock = os.clock()
+ require('core.diagnostics.' .. name)(uri, function (result)
+ result.level = severity or result.level
+ result.code = name
+ results[#results+1] = result
+ end, name)
+ local passed = os.clock() - clock
+ if passed >= 0.5 then
+ log.warn(('Diagnostics [%s] @ [%s] takes [%.3f] sec!'):format(name, uri, passed))
+ end
+end
+
+return function (uri, response)
+ local vm = require 'vm'
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+
+ local isOpen = files.isOpen(uri)
+
+ for _, name in ipairs(diagList) do
+ await.delay()
+ local results = {}
+ check(uri, name, results)
+ response(results)
+ end
+end
diff --git a/script/core/diagnostics/lowercase-global.lua b/script/core/diagnostics/lowercase-global.lua
new file mode 100644
index 00000000..fe5d1eca
--- /dev/null
+++ b/script/core/diagnostics/lowercase-global.lua
@@ -0,0 +1,56 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local config = require 'config'
+local vm = require 'vm'
+
+local function isDocClass(source)
+ if not source.bindDocs then
+ return false
+ end
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.class' then
+ return true
+ end
+ end
+ return false
+end
+
+-- 不允许定义首字母小写的全局变量(很可能是拼错或者漏删)
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ local definedGlobal = {}
+ for name in pairs(config.config.diagnostics.globals) do
+ definedGlobal[name] = true
+ end
+
+ guide.eachSourceType(ast.ast, 'setglobal', function (source)
+ local name = guide.getName(source)
+ if definedGlobal[name] then
+ return
+ end
+ local first = name:match '%w'
+ if not first then
+ return
+ end
+ if not first:match '%l' then
+ return
+ end
+ -- 如果赋值被标记为 doc.class ,则认为是允许的
+ if isDocClass(source) then
+ return
+ end
+ if vm.isGlobalLibraryName(name) then
+ return
+ end
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script.DIAG_LOWERCASE_GLOBAL,
+ }
+ end)
+end
diff --git a/script/core/diagnostics/newfield-call.lua b/script/core/diagnostics/newfield-call.lua
new file mode 100644
index 00000000..75681cbc
--- /dev/null
+++ b/script/core/diagnostics/newfield-call.lua
@@ -0,0 +1,37 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+
+ guide.eachSourceType(ast.ast, 'table', function (source)
+ for i = 1, #source do
+ local field = source[i]
+ if field.type == 'call' then
+ local func = field.node
+ local args = field.args
+ if args then
+ local funcLine = guide.positionOf(lines, func.finish)
+ local argsLine = guide.positionOf(lines, args.start)
+ if argsLine > funcLine then
+ callback {
+ start = field.start,
+ finish = field.finish,
+ message = lang.script('DIAG_PREFIELD_CALL'
+ , text:sub(func.start, func.finish)
+ , text:sub(args.start, args.finish)
+ )
+ }
+ end
+ end
+ end
+ end
+ end)
+end
diff --git a/script/core/diagnostics/newline-call.lua b/script/core/diagnostics/newline-call.lua
new file mode 100644
index 00000000..cb318380
--- /dev/null
+++ b/script/core/diagnostics/newline-call.lua
@@ -0,0 +1,38 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ local lines = files.getLines(uri)
+
+ guide.eachSourceType(ast.ast, 'call', function (source)
+ local node = source.node
+ local args = source.args
+ if not args then
+ return
+ end
+
+ -- 必须有其他人在继续使用当前对象
+ if not source.next then
+ return
+ end
+
+ local nodeRow = guide.positionOf(lines, node.finish)
+ local argRow = guide.positionOf(lines, args.start)
+ if nodeRow == argRow then
+ return
+ end
+
+ if #args == 1 then
+ callback {
+ start = args.start,
+ finish = args.finish,
+ message = lang.script.DIAG_PREVIOUS_CALL,
+ }
+ end
+ end)
+end
diff --git a/script/core/diagnostics/redefined-local.lua b/script/core/diagnostics/redefined-local.lua
new file mode 100644
index 00000000..5e53d837
--- /dev/null
+++ b/script/core/diagnostics/redefined-local.lua
@@ -0,0 +1,32 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ guide.eachSourceType(ast.ast, 'local', function (source)
+ local name = source[1]
+ if name == '_'
+ or name == ast.ENVMode then
+ return
+ end
+ local exist = guide.getLocal(source, name, source.start-1)
+ if exist then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script('DIAG_REDEFINED_LOCAL', name),
+ related = {
+ {
+ start = exist.start,
+ finish = exist.finish,
+ uri = uri,
+ }
+ },
+ }
+ end
+ end)
+end
diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua
new file mode 100644
index 00000000..2fae20e8
--- /dev/null
+++ b/script/core/diagnostics/redundant-parameter.lua
@@ -0,0 +1,82 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local lang = require 'language'
+local define = require 'proto.define'
+local await = require 'await'
+
+local function countCallArgs(source)
+ local result = 0
+ if not source.args then
+ return 0
+ end
+ if source.node and source.node.type == 'getmethod' then
+ result = result + 1
+ end
+ result = result + #source.args
+ return result
+end
+
+local function countFuncArgs(source)
+ local result = 0
+ if source.parent and source.parent.type == 'setmethod' then
+ result = result + 1
+ end
+ if not source.args then
+ return result
+ end
+ if source.args[#source.args].type == '...' then
+ return math.maxinteger
+ end
+ result = result + #source.args
+ return result
+end
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ guide.eachSourceType(ast.ast, 'call', function (source)
+ local callArgs = countCallArgs(source)
+ if callArgs == 0 then
+ return
+ end
+
+ local func = source.node
+ local funcArgs
+ local defs = vm.getDefs(func)
+ for _, def in ipairs(defs) do
+ if def.value then
+ def = def.value
+ end
+ if def.type == 'function' then
+ local args = countFuncArgs(def)
+ if not funcArgs or args > funcArgs then
+ funcArgs = args
+ end
+ end
+ end
+
+ if not funcArgs then
+ return
+ end
+
+ local delta = callArgs - funcArgs
+ if delta <= 0 then
+ return
+ end
+ for i = #source.args - delta + 1, #source.args do
+ local arg = source.args[i]
+ if arg then
+ callback {
+ start = arg.start,
+ finish = arg.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, callArgs)
+ }
+ end
+ end
+ end)
+end
diff --git a/script/core/diagnostics/redundant-value.lua b/script/core/diagnostics/redundant-value.lua
new file mode 100644
index 00000000..be483448
--- /dev/null
+++ b/script/core/diagnostics/redundant-value.lua
@@ -0,0 +1,24 @@
+local files = require 'files'
+local define = require 'proto.define'
+local lang = require 'language'
+
+return function (uri, callback, code)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ local diags = ast.diags[code]
+ if not diags then
+ return
+ end
+
+ for _, info in ipairs(diags) do
+ callback {
+ start = info.start,
+ finish = info.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script('DIAG_OVER_MAX_VALUES', info.max, info.passed)
+ }
+ end
+end
diff --git a/script/core/diagnostics/trailing-space.lua b/script/core/diagnostics/trailing-space.lua
new file mode 100644
index 00000000..e54a6e60
--- /dev/null
+++ b/script/core/diagnostics/trailing-space.lua
@@ -0,0 +1,55 @@
+local files = require 'files'
+local lang = require 'language'
+local guide = require 'parser.guide'
+
+local function isInString(ast, offset)
+ local result = false
+ guide.eachSourceType(ast, 'string', function (source)
+ if offset >= source.start and offset <= source.finish then
+ result = true
+ end
+ end)
+ return result
+end
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ local text = files.getText(uri)
+ local lines = files.getLines(uri)
+ for i = 1, #lines do
+ local start = lines[i].start
+ local range = lines[i].range
+ local lastChar = text:sub(range, range)
+ if lastChar ~= ' ' and lastChar ~= '\t' then
+ goto NEXT_LINE
+ end
+ if isInString(ast.ast, range) then
+ goto NEXT_LINE
+ end
+ local first = start
+ for n = range - 1, start, -1 do
+ local char = text:sub(n, n)
+ if char ~= ' ' and char ~= '\t' then
+ first = n + 1
+ break
+ end
+ end
+ if first == start then
+ callback {
+ start = first,
+ finish = range,
+ message = lang.script.DIAG_LINE_ONLY_SPACE,
+ }
+ else
+ callback {
+ start = first,
+ finish = range,
+ message = lang.script.DIAG_LINE_POST_SPACE,
+ }
+ end
+ ::NEXT_LINE::
+ end
+end
diff --git a/script/core/diagnostics/undefined-doc-class.lua b/script/core/diagnostics/undefined-doc-class.lua
new file mode 100644
index 00000000..bbfdceec
--- /dev/null
+++ b/script/core/diagnostics/undefined-doc-class.lua
@@ -0,0 +1,46 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+local vm = require 'vm'
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ local cache = {}
+ for _, doc in ipairs(state.ast.docs) do
+ if doc.type == 'doc.class' then
+ local ext = doc.extends
+ if not ext then
+ goto CONTINUE
+ end
+ local name = ext[1]
+ local docs = vm.getDocTypes(name)
+ if cache[name] == nil then
+ cache[name] = false
+ for _, otherDoc in ipairs(docs) do
+ if otherDoc.type == 'doc.class.name' then
+ cache[name] = true
+ break
+ end
+ end
+ end
+ if not cache[name] then
+ callback {
+ start = ext.start,
+ finish = ext.finish,
+ related = cache,
+ message = lang.script('DIAG_UNDEFINED_DOC_CLASS', name)
+ }
+ end
+ end
+ ::CONTINUE::
+ end
+end
diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua
new file mode 100644
index 00000000..5c1e8fbf
--- /dev/null
+++ b/script/core/diagnostics/undefined-doc-name.lua
@@ -0,0 +1,60 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+local vm = require 'vm'
+
+local function hasNameOfClassOrAlias(name)
+ local docs = vm.getDocTypes(name)
+ for _, otherDoc in ipairs(docs) do
+ if otherDoc.type == 'doc.class.name'
+ or otherDoc.type == 'doc.alias.name' then
+ return true
+ end
+ end
+ return false
+end
+
+local function hasNameOfGeneric(name, source)
+ if not source.typeGeneric then
+ return false
+ end
+ if not source.typeGeneric[name] then
+ return false
+ end
+ return true
+end
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ guide.eachSource(state.ast.docs, function (source)
+ if source.type ~= 'doc.extends.name'
+ and source.type ~= 'doc.type.name' then
+ return
+ end
+ if source.parent.type == 'doc.class' then
+ return
+ end
+ local name = source[1]
+ if name == '...' then
+ return
+ end
+ if hasNameOfClassOrAlias(name)
+ or hasNameOfGeneric(name, source) then
+ return
+ end
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script('DIAG_UNDEFINED_DOC_NAME', name)
+ }
+ end)
+end
diff --git a/script/core/diagnostics/undefined-doc-param.lua b/script/core/diagnostics/undefined-doc-param.lua
new file mode 100644
index 00000000..af3e07bc
--- /dev/null
+++ b/script/core/diagnostics/undefined-doc-param.lua
@@ -0,0 +1,52 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+local vm = require 'vm'
+
+local function hasParamName(func, name)
+ if not func.args then
+ return false
+ end
+ for _, arg in ipairs(func.args) do
+ if arg[1] == name then
+ return true
+ end
+ end
+ return false
+end
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ for _, doc in ipairs(state.ast.docs) do
+ if doc.type ~= 'doc.param' then
+ goto CONTINUE
+ end
+ local binds = doc.bindSources
+ if not binds then
+ goto CONTINUE
+ end
+ local param = doc.param
+ local name = param[1]
+ for _, source in ipairs(binds) do
+ if source.type == 'function' then
+ if not hasParamName(source, name) then
+ callback {
+ start = param.start,
+ finish = param.finish,
+ message = lang.script('DIAG_UNDEFINED_DOC_PARAM', name)
+ }
+ end
+ end
+ end
+ ::CONTINUE::
+ end
+end
diff --git a/script/core/diagnostics/undefined-env-child.lua b/script/core/diagnostics/undefined-env-child.lua
new file mode 100644
index 00000000..6b8c62f0
--- /dev/null
+++ b/script/core/diagnostics/undefined-env-child.lua
@@ -0,0 +1,27 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ guide.eachSourceType(ast.ast, 'getglobal', function (source)
+ -- 单独验证自己是否在重载过的 _ENV 中有定义
+ if source.node.tag == '_ENV' then
+ return
+ end
+ local defs = guide.requestDefinition(source)
+ if #defs > 0 then
+ return
+ end
+ local key = source[1]
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script('DIAG_UNDEF_ENV_CHILD', key),
+ }
+ end)
+end
diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua
new file mode 100644
index 00000000..778fc1f1
--- /dev/null
+++ b/script/core/diagnostics/undefined-global.lua
@@ -0,0 +1,40 @@
+local files = require 'files'
+local vm = require 'vm'
+local lang = require 'language'
+local config = require 'config'
+local guide = require 'parser.guide'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ -- 遍历全局变量,检查所有没有 set 模式的全局变量
+ guide.eachSourceType(ast.ast, 'getglobal', function (src)
+ local key = guide.getName(src)
+ if not key then
+ return
+ end
+ if config.config.diagnostics.globals[key] then
+ return
+ end
+ if #vm.getGlobalSets(guide.getKeyName(src)) > 0 then
+ return
+ end
+ local message = lang.script('DIAG_UNDEF_GLOBAL', key)
+ -- TODO check other version
+ local otherVersion
+ local customVersion
+ if otherVersion then
+ message = ('%s(%s)'):format(message, lang.script('DIAG_DEFINED_VERSION', table.concat(otherVersion, '/'), config.config.runtime.version))
+ elseif customVersion then
+ message = ('%s(%s)'):format(message, lang.script('DIAG_DEFINED_CUSTOM', table.concat(customVersion, '/')))
+ end
+ callback {
+ start = src.start,
+ finish = src.finish,
+ message = message,
+ }
+ end)
+end
diff --git a/script/core/diagnostics/unused-function.lua b/script/core/diagnostics/unused-function.lua
new file mode 100644
index 00000000..f0bca613
--- /dev/null
+++ b/script/core/diagnostics/unused-function.lua
@@ -0,0 +1,40 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local define = require 'proto.define'
+local lang = require 'language'
+local await = require 'await'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ -- 只检查局部函数
+ guide.eachSourceType(ast.ast, 'function', function (source)
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ if parent.type ~= 'local'
+ and parent.type ~= 'setlocal' then
+ return
+ end
+ local hasGet
+ local refs = vm.getRefs(source)
+ for _, src in ipairs(refs) do
+ if vm.isGet(src) then
+ hasGet = true
+ break
+ end
+ end
+ if not hasGet then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_UNUSED_FUNCTION,
+ }
+ end
+ end)
+end
diff --git a/script/core/diagnostics/unused-label.lua b/script/core/diagnostics/unused-label.lua
new file mode 100644
index 00000000..e6d998ba
--- /dev/null
+++ b/script/core/diagnostics/unused-label.lua
@@ -0,0 +1,22 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local define = require 'proto.define'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ guide.eachSourceType(ast.ast, 'label', function (source)
+ if not source.ref then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script('DIAG_UNUSED_LABEL', source[1]),
+ }
+ end
+ end)
+end
diff --git a/script/core/diagnostics/unused-local.lua b/script/core/diagnostics/unused-local.lua
new file mode 100644
index 00000000..873a70f2
--- /dev/null
+++ b/script/core/diagnostics/unused-local.lua
@@ -0,0 +1,93 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local define = require 'proto.define'
+local lang = require 'language'
+
+local function hasGet(loc)
+ if not loc.ref then
+ return false
+ end
+ local weak
+ for _, ref in ipairs(loc.ref) do
+ if ref.type == 'getlocal' then
+ if not ref.next then
+ return 'strong'
+ end
+ local nextType = ref.next.type
+ if nextType ~= 'setmethod'
+ and nextType ~= 'setfield'
+ and nextType ~= 'setindex' then
+ return 'strong'
+ else
+ weak = true
+ end
+ end
+ end
+ if weak then
+ return 'weak'
+ else
+ return nil
+ end
+end
+
+local function isMyTable(loc)
+ local value = loc.value
+ if value and value.type == 'table' then
+ return true
+ end
+ return false
+end
+
+local function isClose(source)
+ if not source.attrs then
+ return false
+ end
+ for _, attr in ipairs(source.attrs) do
+ if attr[1] == 'close' then
+ return true
+ end
+ end
+ return false
+end
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ guide.eachSourceType(ast.ast, 'local', function (source)
+ local name = source[1]
+ if name == '_'
+ or name == ast.ENVMode then
+ return
+ end
+ if isClose(source) then
+ return
+ end
+ local data = hasGet(source)
+ if data == 'strong' then
+ return
+ end
+ if data == 'weak' then
+ if not isMyTable(source) then
+ return
+ end
+ end
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script('DIAG_UNUSED_LOCAL', name),
+ }
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ callback {
+ start = ref.start,
+ finish = ref.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script('DIAG_UNUSED_LOCAL', name),
+ }
+ end
+ end
+ end)
+end
diff --git a/script/core/diagnostics/unused-vararg.lua b/script/core/diagnostics/unused-vararg.lua
new file mode 100644
index 00000000..74cc08e7
--- /dev/null
+++ b/script/core/diagnostics/unused-vararg.lua
@@ -0,0 +1,31 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local define = require 'proto.define'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ guide.eachSourceType(ast.ast, 'function', function (source)
+ local args = source.args
+ if not args then
+ return
+ end
+
+ for _, arg in ipairs(args) do
+ if arg.type == '...' then
+ if not arg.ref then
+ callback {
+ start = arg.start,
+ finish = arg.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_UNUSED_VARARG,
+ }
+ end
+ end
+ end
+ end)
+end
diff --git a/script/core/document-symbol.lua b/script/core/document-symbol.lua
new file mode 100644
index 00000000..7392b337
--- /dev/null
+++ b/script/core/document-symbol.lua
@@ -0,0 +1,307 @@
+local await = require 'await'
+local files = require 'files'
+local guide = require 'parser.guide'
+local define = require 'proto.define'
+local util = require 'utility'
+
+local function buildName(source, text)
+ if source.type == 'setmethod'
+ or source.type == 'getmethod' then
+ if source.method then
+ return text:sub(source.start, source.method.finish)
+ end
+ end
+ if source.type == 'setfield'
+ or source.type == 'tablefield'
+ or source.type == 'getfield' then
+ if source.field then
+ return text:sub(source.start, source.field.finish)
+ end
+ end
+ return text:sub(source.start, source.finish)
+end
+
+local function buildFunctionParams(func)
+ if not func.args then
+ return ''
+ end
+ local params = {}
+ for i, arg in ipairs(func.args) do
+ if arg.type == '...' then
+ params[i] = '...'
+ else
+ params[i] = arg[1] or ''
+ end
+ end
+ return table.concat(params, ', ')
+end
+
+local function buildFunction(source, text, symbols)
+ local name = buildName(source, text)
+ local func = source.value
+ if source.type == 'tablefield'
+ or source.type == 'setfield' then
+ source = source.field
+ if not source then
+ return
+ end
+ end
+ local range, kind
+ if func.start > source.finish then
+ -- a = function()
+ range = { source.start, func.finish }
+ else
+ -- function f()
+ range = { func.start, func.finish }
+ end
+ if source.type == 'setmethod' then
+ kind = define.SymbolKind.Method
+ else
+ kind = define.SymbolKind.Function
+ end
+ symbols[#symbols+1] = {
+ name = name,
+ detail = ('function (%s)'):format(buildFunctionParams(func)),
+ kind = kind,
+ range = range,
+ selectionRange = { source.start, source.finish },
+ valueRange = { func.start, func.finish },
+ }
+end
+
+local function buildTable(tbl)
+ local buf = {}
+ for i = 1, 3 do
+ local field = tbl[i]
+ if not field then
+ break
+ end
+ if field.type == 'tablefield' then
+ buf[i] = ('%s'):format(field.field[1])
+ end
+ end
+ return table.concat(buf, ', ')
+end
+
+local function buildValue(source, text, symbols)
+ local name = buildName(source, text)
+ local range, sRange, valueRange, kind
+ local details = {}
+ if source.type == 'local' then
+ if source.parent.type == 'funcargs' then
+ details[1] = 'param'
+ range = { source.start, source.finish }
+ sRange = { source.start, source.finish }
+ kind = define.SymbolKind.Constant
+ else
+ details[1] = 'local'
+ range = { source.start, source.finish }
+ sRange = { source.start, source.finish }
+ kind = define.SymbolKind.Variable
+ end
+ elseif source.type == 'setlocal' then
+ details[1] = 'setlocal'
+ range = { source.start, source.finish }
+ sRange = { source.start, source.finish }
+ kind = define.SymbolKind.Variable
+ elseif source.type == 'setglobal' then
+ details[1] = 'global'
+ range = { source.start, source.finish }
+ sRange = { source.start, source.finish }
+ kind = define.SymbolKind.Class
+ elseif source.type == 'tablefield' then
+ if not source.field then
+ return
+ end
+ details[1] = 'field'
+ range = { source.field.start, source.field.finish }
+ sRange = { source.field.start, source.field.finish }
+ kind = define.SymbolKind.Property
+ elseif source.type == 'setfield' then
+ if not source.field then
+ return
+ end
+ details[1] = 'field'
+ range = { source.field.start, source.field.finish }
+ sRange = { source.field.start, source.field.finish }
+ kind = define.SymbolKind.Field
+ else
+ return
+ end
+ if source.value then
+ local literal = source.value[1]
+ if source.value.type == 'boolean' then
+ details[2] = ' boolean'
+ if literal ~= nil then
+ details[3] = ' = '
+ details[4] = util.viewLiteral(source.value[1])
+ end
+ elseif source.value.type == 'string' then
+ details[2] = ' string'
+ if literal ~= nil then
+ details[3] = ' = '
+ details[4] = util.viewLiteral(source.value[1])
+ end
+ elseif source.value.type == 'number' then
+ details[2] = ' number'
+ if literal ~= nil then
+ details[3] = ' = '
+ details[4] = util.viewLiteral(source.value[1])
+ end
+ elseif source.value.type == 'table' then
+ details[2] = ' {'
+ details[3] = buildTable(source.value)
+ details[4] = '}'
+ valueRange = { source.value.start, source.value.finish }
+ elseif source.value.type == 'select' then
+ if source.value.vararg and source.value.vararg.type == 'call' then
+ valueRange = { source.value.start, source.value.finish }
+ end
+ end
+ range = { range[1], source.value.finish }
+ end
+ symbols[#symbols+1] = {
+ name = name,
+ detail = table.concat(details),
+ kind = kind,
+ range = range,
+ selectionRange = sRange,
+ valueRange = valueRange,
+ }
+end
+
+local function buildSet(source, text, used, symbols)
+ local value = source.value
+ if value and value.type == 'function' then
+ used[value] = true
+ buildFunction(source, text, symbols)
+ else
+ buildValue(source, text, symbols)
+ end
+end
+
+local function buildAnonymousFunction(source, text, used, symbols)
+ if used[source] then
+ return
+ end
+ used[source] = true
+ local head = ''
+ local parent = source.parent
+ if parent.type == 'return' then
+ head = 'return '
+ elseif parent.type == 'callargs' then
+ local call = parent.parent
+ local node = call.node
+ head = buildName(node, text) .. ' -> '
+ end
+ symbols[#symbols+1] = {
+ name = '',
+ detail = ('%sfunction (%s)'):format(head, buildFunctionParams(source)),
+ kind = define.SymbolKind.Function,
+ range = { source.start, source.finish },
+ selectionRange = { source.start, source.start },
+ valueRange = { source.start, source.finish },
+ }
+end
+
+local function buildSource(source, text, used, symbols)
+ if source.type == 'local'
+ or source.type == 'setlocal'
+ or source.type == 'setglobal'
+ or source.type == 'setfield'
+ or source.type == 'setmethod'
+ or source.type == 'tablefield' then
+ await.delay()
+ buildSet(source, text, used, symbols)
+ elseif source.type == 'function' then
+ await.delay()
+ buildAnonymousFunction(source, text, used, symbols)
+ end
+end
+
+local function makeSymbol(uri)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+
+ local text = files.getText(uri)
+ local symbols = {}
+ local used = {}
+ guide.eachSource(ast.ast, function (source)
+ buildSource(source, text, used, symbols)
+ end)
+
+ return symbols
+end
+
+local function packChild(ranges, symbols)
+ await.delay()
+ table.sort(symbols, function (a, b)
+ return a.selectionRange[1] < b.selectionRange[1]
+ end)
+ await.delay()
+ local root = {
+ valueRange = { 0, math.maxinteger },
+ children = {},
+ }
+ local stacks = { root }
+ for _, symbol in ipairs(symbols) do
+ local parent = stacks[#stacks]
+ -- 移除已经超出生效范围的区间
+ while symbol.selectionRange[1] > parent.valueRange[2] do
+ stacks[#stacks] = nil
+ parent = stacks[#stacks]
+ end
+ -- 向后看,找出当前可能生效的区间
+ local nextRange
+ while #ranges > 0
+ and symbol.selectionRange[1] >= ranges[#ranges].valueRange[1] do
+ if symbol.selectionRange[1] <= ranges[#ranges].valueRange[2] then
+ nextRange = ranges[#ranges]
+ end
+ ranges[#ranges] = nil
+ end
+ if nextRange then
+ stacks[#stacks+1] = nextRange
+ parent = nextRange
+ end
+ if parent == symbol then
+ -- function f() end 的情况,selectionRange 在 valueRange 内部,
+ -- 当前区间置为上一层
+ parent = stacks[#stacks-1]
+ end
+ -- 把自己放到当前区间中
+ if not parent.children then
+ parent.children = {}
+ end
+ parent.children[#parent.children+1] = symbol
+ end
+ return root.children
+end
+
+local function packSymbols(symbols)
+ local ranges = {}
+ for _, symbol in ipairs(symbols) do
+ if symbol.valueRange then
+ ranges[#ranges+1] = symbol
+ end
+ end
+ await.delay()
+ table.sort(ranges, function (a, b)
+ return a.valueRange[1] > b.valueRange[1]
+ end)
+ -- 处理嵌套
+ return packChild(ranges, symbols)
+end
+
+return function (uri)
+ local symbols = makeSymbol(uri)
+ if not symbols then
+ return nil
+ end
+
+ local packedSymbols = packSymbols(symbols)
+
+ return packedSymbols
+end
diff --git a/script/core/find-source.lua b/script/core/find-source.lua
new file mode 100644
index 00000000..32de102c
--- /dev/null
+++ b/script/core/find-source.lua
@@ -0,0 +1,14 @@
+local guide = require 'parser.guide'
+
+return function (ast, offset, accept)
+ local len = math.huge
+ local result
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ local start, finish = guide.getStartFinish(source)
+ if finish - start < len and accept[source.type] then
+ result = source
+ len = finish - start
+ end
+ end)
+ return result
+end
diff --git a/script/core/highlight.lua b/script/core/highlight.lua
new file mode 100644
index 00000000..d7671df2
--- /dev/null
+++ b/script/core/highlight.lua
@@ -0,0 +1,252 @@
+local guide = require 'parser.guide'
+local files = require 'files'
+local vm = require 'vm'
+local define = require 'proto.define'
+local findSource = require 'core.find-source'
+
+local function eachRef(source, callback)
+ local results = guide.requestReference(source)
+ for i = 1, #results do
+ callback(results[i])
+ end
+end
+
+local function eachField(source, callback)
+ local isGlobal = guide.isGlobal(source)
+ local results = guide.requestReference(source)
+ for i = 1, #results do
+ local res = results[i]
+ if isGlobal == guide.isGlobal(res) then
+ callback(res)
+ end
+ end
+end
+
+local function eachLocal(source, callback)
+ callback(source)
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ callback(ref)
+ end
+ end
+end
+
+local function find(source, uri, callback)
+ if source.type == 'local' then
+ eachLocal(source, callback)
+ elseif source.type == 'getlocal'
+ or source.type == 'setlocal' then
+ eachLocal(source.node, callback)
+ elseif source.type == 'field'
+ or source.type == 'method' then
+ eachField(source.parent, callback)
+ elseif source.type == 'getindex'
+ or source.type == 'setindex'
+ or source.type == 'tableindex' then
+ eachField(source, callback)
+ elseif source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ eachField(source, callback)
+ elseif source.type == 'goto'
+ or source.type == 'label' then
+ eachRef(source, callback)
+ elseif source.type == 'string'
+ and source.parent.index == source then
+ eachField(source.parent, callback)
+ elseif source.type == 'string'
+ or source.type == 'boolean'
+ or source.type == 'number'
+ or source.type == 'nil' then
+ callback(source)
+ end
+end
+
+local function checkInIf(source, text, offset)
+ -- 检查 end
+ local endA = source.finish - #'end' + 1
+ local endB = source.finish
+ if offset >= endA
+ and offset <= endB
+ and text:sub(endA, endB) == 'end' then
+ return true
+ end
+ -- 检查每个子模块
+ for _, block in ipairs(source) do
+ for i = 1, #block.keyword, 2 do
+ local start = block.keyword[i]
+ local finish = block.keyword[i+1]
+ if offset >= start and offset <= finish then
+ return true
+ end
+ end
+ end
+ return false
+end
+
+local function makeIf(source, text, callback)
+ -- end
+ local endA = source.finish - #'end' + 1
+ local endB = source.finish
+ if text:sub(endA, endB) == 'end' then
+ callback(endA, endB)
+ end
+ -- 每个子模块
+ for _, block in ipairs(source) do
+ for i = 1, #block.keyword, 2 do
+ local start = block.keyword[i]
+ local finish = block.keyword[i+1]
+ callback(start, finish)
+ end
+ end
+ return false
+end
+
+local function findKeyWord(ast, text, offset, callback)
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ if source.type == 'do'
+ or source.type == 'function'
+ or source.type == 'loop'
+ or source.type == 'in'
+ or source.type == 'while'
+ or source.type == 'repeat' then
+ local ok
+ for i = 1, #source.keyword, 2 do
+ local start = source.keyword[i]
+ local finish = source.keyword[i+1]
+ if offset >= start and offset <= finish then
+ ok = true
+ break
+ end
+ end
+ if ok then
+ for i = 1, #source.keyword, 2 do
+ local start = source.keyword[i]
+ local finish = source.keyword[i+1]
+ callback(start, finish)
+ end
+ end
+ elseif source.type == 'if' then
+ local ok = checkInIf(source, text, offset)
+ if ok then
+ makeIf(source, text, callback)
+ end
+ end
+ end)
+end
+
+local accept = {
+ ['label'] = true,
+ ['goto'] = true,
+ ['local'] = true,
+ ['setlocal'] = true,
+ ['getlocal'] = true,
+ ['field'] = true,
+ ['method'] = true,
+ ['tablefield'] = true,
+ ['setglobal'] = true,
+ ['getglobal'] = true,
+ ['string'] = true,
+ ['boolean'] = true,
+ ['number'] = true,
+ ['nil'] = true,
+}
+
+return function (uri, offset)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+ local text = files.getText(uri)
+ local results = {}
+ local mark = {}
+
+ local source = findSource(ast, offset, accept)
+ if source then
+ find(source, uri, function (target)
+ local kind
+ if target.type == 'getfield' then
+ target = target.field
+ kind = define.DocumentHighlightKind.Read
+ elseif target.type == 'setfield'
+ or target.type == 'tablefield' then
+ target = target.field
+ kind = define.DocumentHighlightKind.Write
+ elseif target.type == 'getmethod' then
+ target = target.method
+ kind = define.DocumentHighlightKind.Read
+ elseif target.type == 'setmethod' then
+ target = target.method
+ kind = define.DocumentHighlightKind.Write
+ elseif target.type == 'getindex' then
+ target = target.index
+ kind = define.DocumentHighlightKind.Read
+ elseif target.type == 'field' then
+ if target.parent.type == 'getfield' then
+ kind = define.DocumentHighlightKind.Read
+ else
+ kind = define.DocumentHighlightKind.Write
+ end
+ elseif target.type == 'method' then
+ if target.parent.type == 'getmethod' then
+ kind = define.DocumentHighlightKind.Read
+ else
+ kind = define.DocumentHighlightKind.Write
+ end
+ elseif target.type == 'index' then
+ if target.parent.type == 'getindex' then
+ kind = define.DocumentHighlightKind.Read
+ else
+ kind = define.DocumentHighlightKind.Write
+ end
+ elseif target.type == 'index' then
+ if target.parent.type == 'getindex' then
+ kind = define.DocumentHighlightKind.Read
+ else
+ kind = define.DocumentHighlightKind.Write
+ end
+ elseif target.type == 'setindex'
+ or target.type == 'tableindex' then
+ target = target.index
+ kind = define.DocumentHighlightKind.Write
+ elseif target.type == 'getlocal'
+ or target.type == 'getglobal'
+ or target.type == 'goto' then
+ kind = define.DocumentHighlightKind.Read
+ elseif target.type == 'setlocal'
+ or target.type == 'local'
+ or target.type == 'setglobal'
+ or target.type == 'label' then
+ kind = define.DocumentHighlightKind.Write
+ elseif target.type == 'string'
+ or target.type == 'boolean'
+ or target.type == 'number'
+ or target.type == 'nil' then
+ kind = define.DocumentHighlightKind.Text
+ else
+ return
+ end
+ if mark[target] then
+ return
+ end
+ mark[target] = true
+ results[#results+1] = {
+ start = target.start,
+ finish = target.finish,
+ kind = kind,
+ }
+ end)
+ end
+
+ findKeyWord(ast, text, offset, function (start, finish)
+ results[#results+1] = {
+ start = start,
+ finish = finish,
+ kind = define.DocumentHighlightKind.Write
+ }
+ end)
+
+ if #results == 0 then
+ return nil
+ end
+ return results
+end
diff --git a/script/core/hover/arg.lua b/script/core/hover/arg.lua
new file mode 100644
index 00000000..9cd19f02
--- /dev/null
+++ b/script/core/hover/arg.lua
@@ -0,0 +1,71 @@
+local guide = require 'parser.guide'
+local vm = require 'vm'
+
+local function optionalArg(arg)
+ if not arg.bindDocs then
+ return false
+ end
+ local name = arg[1]
+ for _, doc in ipairs(arg.bindDocs) do
+ if doc.type == 'doc.param' and doc.param[1] == name then
+ return doc.optional
+ end
+ end
+end
+
+local function asFunction(source, oop)
+ if not source.args then
+ return ''
+ end
+ local args = {}
+ for i = 1, #source.args do
+ local arg = source.args[i]
+ local name = arg.name or guide.getName(arg)
+ if name then
+ args[i] = ('%s%s: %s'):format(
+ name,
+ optionalArg(arg) and '?' or '',
+ vm.getInferType(arg)
+ )
+ else
+ args[i] = ('%s'):format(vm.getInferType(arg))
+ end
+ end
+ local methodDef
+ local parent = source.parent
+ if parent and parent.type == 'setmethod' then
+ methodDef = true
+ end
+ if not methodDef and oop then
+ return table.concat(args, ', ', 2)
+ else
+ return table.concat(args, ', ')
+ end
+end
+
+local function asDocFunction(source)
+ if not source.args then
+ return ''
+ end
+ local args = {}
+ for i = 1, #source.args do
+ local arg = source.args[i]
+ local name = arg.name[1]
+ args[i] = ('%s%s: %s'):format(
+ name,
+ arg.optional and '?' or '',
+ vm.getInferType(arg.extends)
+ )
+ end
+ return table.concat(args, ', ')
+end
+
+return function (source, oop)
+ if source.type == 'function' then
+ return asFunction(source, oop)
+ end
+ if source.type == 'doc.type.function' then
+ return asDocFunction(source)
+ end
+ return ''
+end
diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua
new file mode 100644
index 00000000..7d89ee6c
--- /dev/null
+++ b/script/core/hover/description.lua
@@ -0,0 +1,204 @@
+local vm = require 'vm'
+local ws = require 'workspace'
+local furi = require 'file-uri'
+local files = require 'files'
+local guide = require 'parser.guide'
+local markdown = require 'provider.markdown'
+local config = require 'config'
+local lang = require 'language'
+
+local function asStringInRequire(source, literal)
+ local rootPath = ws.path or ''
+ local parent = source.parent
+ if parent and parent.type == 'callargs' then
+ local result, searchers
+ local call = parent.parent
+ local func = call.node
+ local libName = vm.getLibraryName(func)
+ if not libName then
+ return
+ end
+ if libName == 'require' then
+ result, searchers = ws.findUrisByRequirePath(literal)
+ elseif libName == 'dofile'
+ or libName == 'loadfile' then
+ result = ws.findUrisByFilePath(literal)
+ end
+ if result and #result > 0 then
+ for i, uri in ipairs(result) do
+ local searcher = searchers and furi.decode(searchers[uri])
+ uri = files.getOriginUri(uri)
+ local path = furi.decode(uri)
+ if files.eq(path:sub(1, #rootPath), rootPath) then
+ path = path:sub(#rootPath + 1)
+ end
+ path = path:gsub('^[/\\]*', '')
+ if vm.isMetaFile(uri) then
+ result[i] = ('* [[meta]](%s)'):format(uri)
+ elseif searcher then
+ searcher = searcher:sub(#rootPath + 1)
+ searcher = ws.normalize(searcher)
+ result[i] = ('* [%s](%s) %s'):format(path, uri, lang.script('HOVER_USE_LUA_PATH', searcher))
+ else
+ result[i] = ('* [%s](%s)'):format(path, uri)
+ end
+ end
+ table.sort(result)
+ local md = markdown()
+ md:add('md', table.concat(result, '\n'))
+ return md:string()
+ end
+ end
+end
+
+local function asStringView(source, literal)
+ -- 内部包含转义符?
+ local rawLen = source.finish - source.start - 2 * #source[2] + 1
+ if config.config.hover.viewString
+ and (source[2] == '"' or source[2] == "'")
+ and rawLen > #literal then
+ local view = literal
+ local max = config.config.hover.viewStringMax
+ if #view > max then
+ view = view:sub(1, max) .. '...'
+ end
+ local md = markdown()
+ md:add('txt', view)
+ return md:string()
+ end
+end
+
+local function asString(source)
+ local literal = guide.getLiteral(source)
+ if type(literal) ~= 'string' then
+ return nil
+ end
+ return asStringInRequire(source, literal)
+ or asStringView(source, literal)
+end
+
+local function getBindComment(docGroup, base)
+ local lines = {}
+ for _, doc in ipairs(docGroup) do
+ if doc.type == 'doc.comment' then
+ lines[#lines+1] = doc.comment.text:sub(2)
+ elseif #lines > 0 and not base then
+ break
+ elseif doc == base then
+ break
+ else
+ lines = {}
+ end
+ end
+ if #lines == 0 then
+ return nil
+ end
+ return table.concat(lines, '\n')
+end
+
+local function buildEnumChunk(docType, name)
+ local enums = vm.getDocEnums(docType)
+ if #enums == 0 then
+ return
+ end
+ local types = {}
+ for _, tp in ipairs(docType.types) do
+ types[#types+1] = tp[1]
+ end
+ local lines = {}
+ lines[#lines+1] = ('%s: %s'):format(name, table.concat(types))
+ for _, enum in ipairs(enums) do
+ lines[#lines+1] = (' %s %s%s'):format(
+ (enum.default and '->')
+ or (enum.additional and '+>')
+ or ' |',
+ enum[1],
+ enum.comment and (' -- %s'):format(enum.comment) or ''
+ )
+ end
+ return table.concat(lines, '\n')
+end
+
+local function getBindEnums(docGroup)
+ local mark = {}
+ local chunks = {}
+ local returnIndex = 0
+ for _, doc in ipairs(docGroup) do
+ if doc.type == 'doc.param' then
+ local name = doc.param[1]
+ if mark[name] then
+ goto CONTINUE
+ end
+ mark[name] = true
+ chunks[#chunks+1] = buildEnumChunk(doc.extends, name)
+ elseif doc.type == 'doc.return' then
+ for _, rtn in ipairs(doc.returns) do
+ returnIndex = returnIndex + 1
+ local name = rtn.name and rtn.name[1] or ('(return %d)'):format(returnIndex)
+ if mark[name] then
+ goto CONTINUE
+ end
+ mark[name] = true
+ chunks[#chunks+1] = buildEnumChunk(rtn, name)
+ end
+ end
+ ::CONTINUE::
+ end
+ if #chunks == 0 then
+ return nil
+ end
+ return table.concat(chunks, '\n\n')
+end
+
+local function tryDocFieldUpComment(source)
+ if source.type ~= 'doc.field' then
+ return
+ end
+ if not source.bindGroup then
+ return
+ end
+ local comment = getBindComment(source.bindGroup, source)
+ return comment
+end
+
+local function tryDocComment(source)
+ if not source.bindDocs then
+ return
+ end
+ local comment = getBindComment(source.bindDocs)
+ local enums = getBindEnums(source.bindDocs)
+ local md = markdown()
+ if comment then
+ md:add('md', comment)
+ end
+ if enums then
+ md:add('lua', enums)
+ end
+ return md:string()
+end
+
+local function tryDocOverloadToComment(source)
+ if source.type ~= 'doc.type.function' then
+ return
+ end
+ local doc = source.parent
+ if doc.type ~= 'doc.overload'
+ or not doc.bindSources then
+ return
+ end
+ for _, src in ipairs(doc.bindSources) do
+ local md = tryDocComment(src)
+ if md then
+ return md
+ end
+ end
+end
+
+return function (source)
+ if source.type == 'string' then
+ return asString(source)
+ end
+ return tryDocOverloadToComment(source)
+ or tryDocFieldUpComment(source)
+ or tryDocComment(source)
+end
diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua
new file mode 100644
index 00000000..96e01ab5
--- /dev/null
+++ b/script/core/hover/init.lua
@@ -0,0 +1,164 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local getLabel = require 'core.hover.label'
+local getDesc = require 'core.hover.description'
+local util = require 'utility'
+local findSource = require 'core.find-source'
+local lang = require 'language'
+
+local function eachFunctionAndOverload(value, callback)
+ callback(value)
+ if not value.bindDocs then
+ return
+ end
+ for _, doc in ipairs(value.bindDocs) do
+ if doc.type == 'doc.overload' then
+ callback(doc.overload)
+ end
+ end
+end
+
+local function getHoverAsFunction(source)
+ local values = vm.getDefs(source, 'deep')
+ local desc = getDesc(source)
+ local labels = {}
+ local defs = 0
+ local protos = 0
+ local other = 0
+ local oop = source.type == 'method'
+ or source.type == 'getmethod'
+ or source.type == 'setmethod'
+ local mark = {}
+ for _, def in ipairs(values) do
+ def = guide.getObjectValue(def) or def
+ if def.type == 'function'
+ or def.type == 'doc.type.function' then
+ eachFunctionAndOverload(def, function (value)
+ if mark[value] then
+ return
+ end
+ mark[value] =true
+ local label = getLabel(value, oop)
+ if label then
+ defs = defs + 1
+ labels[label] = (labels[label] or 0) + 1
+ if labels[label] == 1 then
+ protos = protos + 1
+ end
+ end
+ desc = desc or getDesc(value)
+ end)
+ elseif def.type == 'table'
+ or def.type == 'boolean'
+ or def.type == 'string'
+ or def.type == 'number' then
+ other = other + 1
+ desc = desc or getDesc(def)
+ end
+ end
+
+ if defs == 1 and other == 0 then
+ return {
+ label = next(labels),
+ source = source,
+ description = desc,
+ }
+ end
+
+ local lines = {}
+ if defs > 1 then
+ lines[#lines+1] = lang.script('HOVER_MULTI_DEF_PROTO', defs, protos)
+ end
+ if other > 0 then
+ lines[#lines+1] = lang.script('HOVER_MULTI_PROTO_NOT_FUNC', other)
+ end
+ if defs > 1 then
+ for label, count in util.sortPairs(labels) do
+ lines[#lines+1] = ('(%d) %s'):format(count, label)
+ end
+ else
+ lines[#lines+1] = next(labels)
+ end
+ local label = table.concat(lines, '\n')
+ return {
+ label = label,
+ source = source,
+ description = desc,
+ }
+end
+
+local function getHoverAsValue(source)
+ local oop = source.type == 'method'
+ or source.type == 'getmethod'
+ or source.type == 'setmethod'
+ local label = getLabel(source, oop)
+ local desc = getDesc(source)
+ if not desc then
+ local values = vm.getDefs(source, 'deep')
+ for _, def in ipairs(values) do
+ desc = getDesc(def)
+ if desc then
+ break
+ end
+ end
+ end
+ return {
+ label = label,
+ source = source,
+ description = desc,
+ }
+end
+
+local function getHoverAsDocName(source)
+ local label = getLabel(source)
+ local desc = getDesc(source)
+ return {
+ label = label,
+ source = source,
+ description = desc,
+ }
+end
+
+local function getHover(source)
+ if source.type == 'doc.type.name' then
+ return getHoverAsDocName(source)
+ end
+ local isFunction = vm.hasInferType(source, 'function', 'deep')
+ if isFunction then
+ return getHoverAsFunction(source)
+ else
+ return getHoverAsValue(source)
+ end
+end
+
+local accept = {
+ ['local'] = true,
+ ['setlocal'] = true,
+ ['getlocal'] = true,
+ ['setglobal'] = true,
+ ['getglobal'] = true,
+ ['field'] = true,
+ ['method'] = true,
+ ['string'] = true,
+ ['number'] = true,
+ ['doc.type.name'] = true,
+}
+
+local function getHoverByUri(uri, offset)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+ local source = findSource(ast, offset, accept)
+ if not source then
+ return nil
+ end
+ local hover = getHover(source)
+ return hover
+end
+
+return {
+ get = getHover,
+ byUri = getHoverByUri,
+}
diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua
new file mode 100644
index 00000000..d785bc27
--- /dev/null
+++ b/script/core/hover/label.lua
@@ -0,0 +1,211 @@
+local buildName = require 'core.hover.name'
+local buildArg = require 'core.hover.arg'
+local buildReturn = require 'core.hover.return'
+local buildTable = require 'core.hover.table'
+local vm = require 'vm'
+local util = require 'utility'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local config = require 'config'
+local files = require 'files'
+
+local function asFunction(source, oop)
+ local name = buildName(source, oop)
+ local arg = buildArg(source, oop)
+ local rtn = buildReturn(source)
+ local lines = {}
+ lines[1] = ('function %s(%s)'):format(name, arg)
+ lines[2] = rtn
+ return table.concat(lines, '\n')
+end
+
+local function asDocFunction(source)
+ local name = buildName(source)
+ local arg = buildArg(source)
+ local rtn = buildReturn(source)
+ local lines = {}
+ lines[1] = ('function %s(%s)'):format(name, arg)
+ lines[2] = rtn
+ return table.concat(lines, '\n')
+end
+
+local function asDocTypeName(source)
+ for _, doc in ipairs(vm.getDocTypes(source[1])) do
+ if doc.type == 'doc.class.name' then
+ return 'class ' .. source[1]
+ end
+ if doc.type == 'doc.alias.name' then
+ local extends = doc.parent.extends
+ return lang.script('HOVER_EXTENDS', vm.getInferType(extends))
+ end
+ end
+end
+
+local function asValue(source, title)
+ local name = buildName(source)
+ local infers = vm.getInfers(source, 'deep')
+ local type = vm.getInferType(source, 'deep')
+ local class = vm.getClass(source, 'deep')
+ local literal = vm.getInferLiteral(source, 'deep')
+ local cont
+ if type ~= 'string' and not type:find('%[%]$') then
+ if #vm.getFields(source, 'deep') > 0
+ or vm.hasInferType(source, 'table', 'deep') then
+ cont = buildTable(source)
+ end
+ end
+ local pack = {}
+ pack[#pack+1] = title
+ pack[#pack+1] = name .. ':'
+ if cont and type == 'table' then
+ type = nil
+ end
+ if class then
+ pack[#pack+1] = class
+ else
+ pack[#pack+1] = type
+ end
+ if literal then
+ pack[#pack+1] = '='
+ pack[#pack+1] = literal
+ end
+ if cont then
+ pack[#pack+1] = cont
+ end
+ return table.concat(pack, ' ')
+end
+
+local function asLocal(source)
+ return asValue(source, 'local')
+end
+
+local function asGlobal(source)
+ return asValue(source, 'global')
+end
+
+local function isGlobalField(source)
+ if source.type == 'field'
+ or source.type == 'method' then
+ source = source.parent
+ end
+ if source.type == 'setfield'
+ or source.type == 'getfield'
+ or source.type == 'setmethod'
+ or source.type == 'getmethod' then
+ local node = source.node
+ if node.type == 'setglobal'
+ or node.type == 'getglobal' then
+ return true
+ end
+ return isGlobalField(node)
+ elseif source.type == 'tablefield' then
+ local parent = source.parent
+ if parent.type == 'setglobal'
+ or parent.type == 'getglobal' then
+ return true
+ end
+ return isGlobalField(parent)
+ else
+ return false
+ end
+end
+
+local function asField(source)
+ if isGlobalField(source) then
+ return asGlobal(source)
+ end
+ return asValue(source, 'field')
+end
+
+local function asDocField(source)
+ local name = source.field[1]
+ local class
+ for _, doc in ipairs(source.bindGroup) do
+ if doc.type == 'doc.class' then
+ class = doc
+ break
+ end
+ end
+ if not class then
+ return ('field ?.%s: %s'):format(
+ name,
+ vm.getInferType(source.extends)
+ )
+ end
+ return ('field %s.%s: %s'):format(
+ class.class[1],
+ name,
+ vm.getInferType(source.extends)
+ )
+end
+
+local function asString(source)
+ local str = source[1]
+ if type(str) ~= 'string' then
+ return ''
+ end
+ local len = #str
+ local charLen = util.utf8Len(str, 1, -1)
+ if len == charLen then
+ return lang.script('HOVER_STRING_BYTES', len)
+ else
+ return lang.script('HOVER_STRING_CHARACTERS', len, charLen)
+ end
+end
+
+local function formatNumber(n)
+ local str = ('%.10f'):format(n)
+ str = str:gsub('%.?0*$', '')
+ return str
+end
+
+local function asNumber(source)
+ if not config.config.hover.viewNumber then
+ return nil
+ end
+ local num = source[1]
+ if type(num) ~= 'number' then
+ return nil
+ end
+ local uri = guide.getUri(source)
+ local text = files.getText(uri)
+ if not text then
+ return nil
+ end
+ local raw = text:sub(source.start, source.finish)
+ if not raw or not raw:find '[^%-%d%.]' then
+ return nil
+ end
+ return formatNumber(num)
+end
+
+return function (source, oop)
+ if source.type == 'function' then
+ return asFunction(source, oop)
+ elseif source.type == 'local'
+ or source.type == 'getlocal'
+ or source.type == 'setlocal' then
+ return asLocal(source)
+ elseif source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ return asGlobal(source)
+ elseif source.type == 'getfield'
+ or source.type == 'setfield'
+ or source.type == 'getmethod'
+ or source.type == 'setmethod'
+ or source.type == 'tablefield'
+ or source.type == 'field'
+ or source.type == 'method' then
+ return asField(source)
+ elseif source.type == 'string' then
+ return asString(source)
+ elseif source.type == 'number' then
+ return asNumber(source)
+ elseif source.type == 'doc.type.function' then
+ return asDocFunction(source)
+ elseif source.type == 'doc.type.name' then
+ return asDocTypeName(source)
+ elseif source.type == 'doc.field' then
+ return asDocField(source)
+ end
+end
diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua
new file mode 100644
index 00000000..9ad32e09
--- /dev/null
+++ b/script/core/hover/name.lua
@@ -0,0 +1,101 @@
+local guide = require 'parser.guide'
+local vm = require 'vm'
+
+local buildName
+
+local function asLocal(source)
+ local name = guide.getName(source)
+ if not source.attrs then
+ return name
+ end
+ local label = {}
+ label[#label+1] = name
+ for _, attr in ipairs(source.attrs) do
+ label[#label+1] = ('<%s>'):format(attr[1])
+ end
+ return table.concat(label, ' ')
+end
+
+local function asField(source, oop)
+ local class
+ if source.node.type ~= 'getglobal' then
+ class = vm.getClass(source.node, 'deep')
+ end
+ local node = class or guide.getName(source.node) or '?'
+ local method = guide.getName(source)
+ if oop then
+ return ('%s:%s'):format(node, method)
+ else
+ return ('%s.%s'):format(node, method)
+ end
+end
+
+local function asTableField(source)
+ if not source.field then
+ return
+ end
+ return guide.getName(source.field)
+end
+
+local function asGlobal(source)
+ return guide.getName(source)
+end
+
+local function asDocFunction(source)
+ local doc = guide.getParentType(source, 'doc.type')
+ or guide.getParentType(source, 'doc.overload')
+ if not doc or not doc.bindSources then
+ return ''
+ end
+ for _, src in ipairs(doc.bindSources) do
+ local name = buildName(src)
+ if name ~= '' then
+ return name
+ end
+ end
+ return ''
+end
+
+local function asDocField(source)
+ return source.field[1]
+end
+
+function buildName(source, oop)
+ if oop == nil then
+ oop = source.type == 'setmethod'
+ or source.type == 'getmethod'
+ end
+ if source.type == 'local'
+ or source.type == 'getlocal'
+ or source.type == 'setlocal' then
+ return asLocal(source) or ''
+ end
+ if source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ return asGlobal(source) or ''
+ end
+ if source.type == 'setmethod'
+ or source.type == 'getmethod' then
+ return asField(source, true) or ''
+ end
+ if source.type == 'setfield'
+ or source.type == 'getfield' then
+ return asField(source, oop) or ''
+ end
+ if source.type == 'tablefield' then
+ return asTableField(source) or ''
+ end
+ if source.type == 'doc.type.function' then
+ return asDocFunction(source)
+ end
+ if source.type == 'doc.field' then
+ return asDocField(source)
+ end
+ local parent = source.parent
+ if parent then
+ return buildName(parent, oop)
+ end
+ return ''
+end
+
+return buildName
diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua
new file mode 100644
index 00000000..3829dbed
--- /dev/null
+++ b/script/core/hover/return.lua
@@ -0,0 +1,125 @@
+local guide = require 'parser.guide'
+local vm = require 'vm'
+
+local function mergeTypes(returns)
+ if type(returns) == 'string' then
+ return returns
+ end
+ return guide.mergeTypes(returns)
+end
+
+local function getReturnDualByDoc(source)
+ local docs = source.bindDocs
+ if not docs then
+ return
+ end
+ local dual
+ for _, doc in ipairs(docs) do
+ if doc.type == 'doc.return' then
+ for _, rtn in ipairs(doc.returns) do
+ if not dual then
+ dual = {}
+ end
+ dual[#dual+1] = { rtn }
+ end
+ end
+ end
+ return dual
+end
+
+local function getReturnDualByGrammar(source)
+ if not source.returns then
+ return nil
+ end
+ local dual
+ for _, rtn in ipairs(source.returns) do
+ if not dual then
+ dual = {}
+ end
+ for n = 1, #rtn do
+ if not dual[n] then
+ dual[n] = {}
+ end
+ dual[n][#dual[n]+1] = rtn[n]
+ end
+ end
+ return dual
+end
+
+local function asFunction(source)
+ local dual = getReturnDualByDoc(source)
+ or getReturnDualByGrammar(source)
+ if not dual then
+ return
+ end
+ local returns = {}
+ for i, rtn in ipairs(dual) do
+ local line = {}
+ local types = {}
+ if i == 1 then
+ line[#line+1] = ' -> '
+ else
+ line[#line+1] = ('% 3d. '):format(i)
+ end
+ for n = 1, #rtn do
+ local values = vm.getInfers(rtn[n])
+ for _, value in ipairs(values) do
+ if value.type then
+ for tp in value.type:gmatch '[^|]+' do
+ types[#types+1] = tp
+ end
+ end
+ end
+ end
+ if #types > 0 or rtn[1] then
+ local tp = mergeTypes(types) or 'any'
+ if rtn[1].name then
+ line[#line+1] = ('%s%s: %s'):format(
+ rtn[1].name[1],
+ rtn[1].optional and '?' or '',
+ tp
+ )
+ else
+ line[#line+1] = ('%s%s'):format(
+ tp,
+ rtn[1].optional and '?' or ''
+ )
+ end
+ else
+ break
+ end
+ returns[i] = table.concat(line)
+ end
+ if #returns == 0 then
+ return nil
+ end
+ return table.concat(returns, '\n')
+end
+
+local function asDocFunction(source)
+ if not source.returns or #source.returns == 0 then
+ return nil
+ end
+ local returns = {}
+ for i, rtn in ipairs(source.returns) do
+ local rtnText = ('%s%s'):format(
+ vm.getInferType(rtn),
+ rtn.optional and '?' or ''
+ )
+ if i == 1 then
+ returns[#returns+1] = (' -> %s'):format(rtnText)
+ else
+ returns[#returns+1] = ('% 3d. %s'):format(i, rtnText)
+ end
+ end
+ return table.concat(returns, '\n')
+end
+
+return function (source)
+ if source.type == 'function' then
+ return asFunction(source)
+ end
+ if source.type == 'doc.type.function' then
+ return asDocFunction(source)
+ end
+end
diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua
new file mode 100644
index 00000000..02be5271
--- /dev/null
+++ b/script/core/hover/table.lua
@@ -0,0 +1,257 @@
+local vm = require 'vm'
+local util = require 'utility'
+local guide = require 'parser.guide'
+local config = require 'config'
+local lang = require 'language'
+
+local function getKey(src)
+ local key = vm.getKeyName(src)
+ if not key or #key <= 2 then
+ if not src.index then
+ return '[any]'
+ end
+ local class = vm.getClass(src.index)
+ if class then
+ return ('[%s]'):format(class)
+ end
+ local tp = vm.getInferType(src.index)
+ if tp then
+ return ('[%s]'):format(tp)
+ end
+ return '[any]'
+ end
+ local ktype = key:sub(1, 2)
+ key = key:sub(3)
+ if ktype == 's|' then
+ if key:match '^[%a_][%w_]*$' then
+ return key
+ else
+ return ('[%s]'):format(util.viewLiteral(key))
+ end
+ end
+ return ('[%s]'):format(key)
+end
+
+local function getFieldFast(src)
+ local value = guide.getObjectValue(src) or src
+ if not value then
+ return 'any'
+ end
+ if value.type == 'boolean' then
+ return value.type, util.viewLiteral(value[1])
+ end
+ if value.type == 'number'
+ or value.type == 'integer' then
+ if math.tointeger(value[1]) then
+ if config.config.runtime.version == 'Lua 5.3'
+ or config.config.runtime.version == 'Lua 5.4' then
+ return 'integer', util.viewLiteral(value[1])
+ end
+ end
+ return value.type, util.viewLiteral(value[1])
+ end
+ if value.type == 'table'
+ or value.type == 'function' then
+ return value.type
+ end
+ if value.type == 'string' then
+ local literal = value[1]
+ if type(literal) == 'string' and #literal >= 50 then
+ literal = literal:sub(1, 47) .. '...'
+ end
+ return value.type, util.viewLiteral(literal)
+ end
+end
+
+local function getFieldFull(src)
+ local tp = vm.getInferType(src)
+ --local class = vm.getClass(src)
+ local literal = vm.getInferLiteral(src)
+ if type(literal) == 'string' and #literal >= 50 then
+ literal = literal:sub(1, 47) .. '...'
+ end
+ return tp, literal
+end
+
+local function getField(src, timeUp, mark, key)
+ if src.type == 'table'
+ or src.type == 'function' then
+ return nil
+ end
+ if src.parent then
+ if src.type == 'string'
+ or src.type == 'boolean'
+ or src.type == 'number'
+ or src.type == 'integer' then
+ if src.parent.type == 'tableindex'
+ or src.parent.type == 'setindex'
+ or src.parent.type == 'getindex' then
+ if src.parent.index == src then
+ src = src.parent
+ end
+ end
+ end
+ end
+ local tp, literal
+ tp, literal = getFieldFast(src)
+ if tp then
+ return tp, literal
+ end
+ if timeUp or mark[key] then
+ return nil
+ end
+ mark[key] = true
+ tp, literal = getFieldFull(src)
+ if tp then
+ return tp, literal
+ end
+ return nil
+end
+
+local function buildAsHash(classes, literals)
+ local keys = {}
+ for k in pairs(classes) do
+ keys[#keys+1] = k
+ end
+ table.sort(keys)
+ local lines = {}
+ lines[#lines+1] = '{'
+ for _, key in ipairs(keys) do
+ local class = classes[key]
+ local literal = literals[key]
+ if literal then
+ lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal)
+ else
+ lines[#lines+1] = (' %s: %s,'):format(key, class)
+ end
+ end
+ lines[#lines+1] = '}'
+ return table.concat(lines, '\n')
+end
+
+local function buildAsConst(classes, literals)
+ local keys = {}
+ for k in pairs(classes) do
+ keys[#keys+1] = k
+ end
+ table.sort(keys, function (a, b)
+ return tonumber(literals[a]) < tonumber(literals[b])
+ end)
+ local lines = {}
+ lines[#lines+1] = '{'
+ for _, key in ipairs(keys) do
+ local class = classes[key]
+ local literal = literals[key]
+ if literal then
+ lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal)
+ else
+ lines[#lines+1] = (' %s: %s,'):format(key, class)
+ end
+ end
+ lines[#lines+1] = '}'
+ return table.concat(lines, '\n')
+end
+
+local function mergeLiteral(literals)
+ local results = {}
+ local mark = {}
+ for _, value in ipairs(literals) do
+ if not mark[value] then
+ mark[value] = true
+ results[#results+1] = value
+ end
+ end
+ if #results == 0 then
+ return nil
+ end
+ table.sort(results)
+ return table.concat(results, '|')
+end
+
+local function mergeTypes(types)
+ local results = {}
+ local mark = {
+ -- 讲道理table的keyvalue不会是nil
+ ['nil'] = true,
+ }
+ for _, tv in ipairs(types) do
+ for tp in tv:gmatch '[^|]+' do
+ if not mark[tp] then
+ mark[tp] = true
+ results[#results+1] = tp
+ end
+ end
+ end
+ return guide.mergeTypes(results)
+end
+
+local function clearClasses(classes)
+ classes['[nil]'] = nil
+ classes['[any]'] = nil
+ classes['[string]'] = nil
+end
+
+return function (source)
+ local literals = {}
+ local classes = {}
+ local clock = os.clock()
+ local timeUp
+ local mark = {}
+ local fields = vm.getFields(source, 'deep')
+ local keyCount = 0
+ for _, src in ipairs(fields) do
+ local key = getKey(src)
+ if not key then
+ goto CONTINUE
+ end
+ if not classes[key] then
+ classes[key] = {}
+ keyCount = keyCount + 1
+ end
+ if not literals[key] then
+ literals[key] = {}
+ end
+ if not TEST and os.clock() - clock > config.config.hover.fieldInfer / 1000.0 then
+ timeUp = true
+ end
+ local class, literal = getField(src, timeUp, mark, key)
+ if literal == 'nil' then
+ literal = nil
+ end
+ classes[key][#classes[key]+1] = class
+ literals[key][#literals[key]+1] = literal
+ if keyCount >= 1000 then
+ break
+ end
+ ::CONTINUE::
+ end
+
+ clearClasses(classes)
+
+ for key, class in pairs(classes) do
+ literals[key] = mergeLiteral(literals[key])
+ classes[key] = mergeTypes(class)
+ end
+
+ if not next(classes) then
+ return '{}'
+ end
+
+ local intValue = true
+ for key, class in pairs(classes) do
+ if class ~= 'integer' or not tonumber(literals[key]) then
+ intValue = false
+ break
+ end
+ end
+ local result
+ if intValue then
+ result = buildAsConst(classes, literals)
+ else
+ result = buildAsHash(classes, literals)
+ end
+ if timeUp then
+ result = ('\n--%s\n%s'):format(lang.script.HOVER_TABLE_TIME_UP, result)
+ end
+ return result
+end
diff --git a/script/core/keyword.lua b/script/core/keyword.lua
new file mode 100644
index 00000000..1cbeb78d
--- /dev/null
+++ b/script/core/keyword.lua
@@ -0,0 +1,264 @@
+local define = require 'proto.define'
+local guide = require 'parser.guide'
+
+local keyWordMap = {
+ {'do', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'do .. end',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[$0 end]],
+ }
+ else
+ results[#results+1] = {
+ label = 'do .. end',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+do
+ $0
+end]],
+ }
+ end
+ return true
+ end, function (ast, start)
+ return guide.eachSourceContain(ast.ast, start, function (source)
+ if source.type == 'while'
+ or source.type == 'in'
+ or source.type == 'loop' then
+ for i = 1, #source.keyword do
+ if start == source.keyword[i] then
+ return true
+ end
+ end
+ end
+ end)
+ end},
+ {'and'},
+ {'break'},
+ {'else'},
+ {'elseif', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'elseif .. then',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[$1 then]],
+ }
+ else
+ results[#results+1] = {
+ label = 'elseif .. then',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[elseif $1 then]],
+ }
+ end
+ return true
+ end},
+ {'end'},
+ {'false'},
+ {'for', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'for .. in',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+${1:key, value} in ${2:pairs(${3:t})} do
+ $0
+end]]
+ }
+ results[#results+1] = {
+ label = 'for i = ..',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+${1:i} = ${2:1}, ${3:10, 1} do
+ $0
+end]]
+ }
+ else
+ results[#results+1] = {
+ label = 'for .. in',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+for ${1:key, value} in ${2:pairs(${3:t})} do
+ $0
+end]]
+ }
+ results[#results+1] = {
+ label = 'for i = ..',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+for ${1:i} = ${2:1}, ${3:10, 1} do
+ $0
+end]]
+ }
+ end
+ return true
+ end},
+ {'function', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'function ()',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+$1($2)
+ $0
+end]]
+ }
+ else
+ results[#results+1] = {
+ label = 'function ()',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+function $1($2)
+ $0
+end]]
+ }
+ end
+ return true
+ end},
+ {'goto'},
+ {'if', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'if .. then',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+$1 then
+ $0
+end]]
+ }
+ else
+ results[#results+1] = {
+ label = 'if .. then',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+if $1 then
+ $0
+end]]
+ }
+ end
+ return true
+ end},
+ {'in', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'in ..',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+${1:pairs(${2:t})} do
+ $0
+end]]
+ }
+ else
+ results[#results+1] = {
+ label = 'in ..',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+in ${1:pairs(${2:t})} do
+ $0
+end]]
+ }
+ end
+ return true
+ end},
+ {'local', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'local function',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+function $1($2)
+ $0
+end]]
+ }
+ else
+ results[#results+1] = {
+ label = 'local function',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+local function $1($2)
+ $0
+end]]
+ }
+ end
+ return false
+ end},
+ {'nil'},
+ {'not'},
+ {'or'},
+ {'repeat', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'repeat .. until',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[$0 until $1]]
+ }
+ else
+ results[#results+1] = {
+ label = 'repeat .. until',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+repeat
+ $0
+until $1]]
+ }
+ end
+ return true
+ end},
+ {'return', function (hasSpace, results)
+ if not hasSpace then
+ results[#results+1] = {
+ label = 'do return end',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[do return $1end]]
+ }
+ end
+ return false
+ end},
+ {'then'},
+ {'true'},
+ {'until'},
+ {'while', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'while .. do',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+${1:true} do
+ $0
+end]]
+ }
+ else
+ results[#results+1] = {
+ label = 'while .. do',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+while ${1:true} do
+ $0
+end]]
+ }
+ end
+ return true
+ end},
+}
+
+return keyWordMap
diff --git a/script/core/matchkey.lua b/script/core/matchkey.lua
new file mode 100644
index 00000000..45c86eff
--- /dev/null
+++ b/script/core/matchkey.lua
@@ -0,0 +1,33 @@
+return function (me, other, fast)
+ if me == other then
+ return true
+ end
+ if me == '' then
+ return true
+ end
+ if #me > #other then
+ return false
+ end
+ local lMe = me:lower()
+ local lOther = other:lower()
+ if lMe == lOther:sub(1, #lMe) then
+ return true
+ end
+ if fast and me:sub(1, 1) ~= other:sub(1, 1) then
+ return false
+ end
+ local chars = {}
+ for i = 1, #lOther do
+ local c = lOther:sub(i, i)
+ chars[c] = (chars[c] or 0) + 1
+ end
+ for i = 1, #lMe do
+ local c = lMe:sub(i, i)
+ if chars[c] and chars[c] > 0 then
+ chars[c] = chars[c] - 1
+ else
+ return false
+ end
+ end
+ return true
+end
diff --git a/script/core/reference.lua b/script/core/reference.lua
new file mode 100644
index 00000000..d7d3df03
--- /dev/null
+++ b/script/core/reference.lua
@@ -0,0 +1,116 @@
+local guide = require 'parser.guide'
+local files = require 'files'
+local vm = require 'vm'
+local findSource = require 'core.find-source'
+
+local function isValidFunction(source, offset)
+ -- 必须点在 `function` 这个单词上才能查找函数引用
+ return offset >= source.start and offset < source.start + #'function'
+end
+
+local function sortResults(results)
+ -- 先按照顺序排序
+ table.sort(results, function (a, b)
+ local u1 = guide.getUri(a.target)
+ local u2 = guide.getUri(b.target)
+ if u1 == u2 then
+ return a.target.start < b.target.start
+ else
+ return u1 < u2
+ end
+ end)
+ -- 如果2个结果处于嵌套状态,则取范围小的那个
+ local lf, lu
+ for i = #results, 1, -1 do
+ local res = results[i].target
+ local f = res.finish
+ local uri = guide.getUri(res)
+ if lf and f > lf and uri == lu then
+ table.remove(results, i)
+ else
+ lu = uri
+ lf = f
+ end
+ end
+end
+
+local accept = {
+ ['local'] = true,
+ ['setlocal'] = true,
+ ['getlocal'] = true,
+ ['label'] = true,
+ ['goto'] = true,
+ ['field'] = true,
+ ['method'] = true,
+ ['setindex'] = true,
+ ['getindex'] = true,
+ ['tableindex'] = true,
+ ['setglobal'] = true,
+ ['getglobal'] = true,
+ ['function'] = true,
+
+ ['doc.type.name'] = true,
+ ['doc.class.name'] = true,
+ ['doc.extends.name'] = true,
+ ['doc.alias.name'] = true,
+}
+
+return function (uri, offset)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+
+ local source = findSource(ast, offset, accept)
+ if not source then
+ return nil
+ end
+ if source.type == 'function' and not isValidFunction(source, offset) and not TEST then
+ return nil
+ end
+
+ local results = {}
+ for _, src in ipairs(vm.getRefs(source, 'deep')) do
+ local root = guide.getRoot(src)
+ if not root then
+ goto CONTINUE
+ end
+ if vm.isMetaFile(root.uri) then
+ goto CONTINUE
+ end
+ if ( src.type == 'doc.class.name'
+ or src.type == 'doc.type.name'
+ )
+ and source.type ~= 'doc.type.name'
+ and source.type ~= 'doc.class.name' then
+ goto CONTINUE
+ end
+ if src.type == 'setfield'
+ or src.type == 'getfield'
+ or src.type == 'tablefield' then
+ src = src.field
+ elseif src.type == 'setindex'
+ or src.type == 'getindex'
+ or src.type == 'tableindex' then
+ src = src.index
+ elseif src.type == 'getmethod'
+ or src.type == 'setmethod' then
+ src = src.method
+ elseif src.type == 'table' and src.parent.type ~= 'return' then
+ goto CONTINUE
+ end
+ results[#results+1] = {
+ target = src,
+ uri = files.getOriginUri(root.uri),
+ }
+ ::CONTINUE::
+ end
+
+ if #results == 0 then
+ return nil
+ end
+
+ sortResults(results)
+
+ return results
+end
diff --git a/script/core/rename.lua b/script/core/rename.lua
new file mode 100644
index 00000000..89298bdd
--- /dev/null
+++ b/script/core/rename.lua
@@ -0,0 +1,448 @@
+local files = require 'files'
+local vm = require 'vm'
+local guide = require 'parser.guide'
+local proto = require 'proto'
+local define = require 'proto.define'
+local util = require 'utility'
+local findSource = require 'core.find-source'
+local ws = require 'workspace'
+
+local Forcing
+
+local function askForcing(str)
+ -- TODO 总是可以替换
+ do return true end
+ if TEST then
+ return true
+ end
+ if Forcing ~= nil then
+ return Forcing
+ end
+ local version = files.globalVersion
+ -- TODO
+ local item = proto.awaitRequest('window/showMessageRequest', {
+ type = define.MessageType.Warning,
+ message = ('[%s]不是有效的标识符,是否强制替换?'):format(str),
+ actions = {
+ {
+ title = '强制替换',
+ },
+ {
+ title = '取消',
+ },
+ }
+ })
+ if version ~= files.globalVersion then
+ Forcing = false
+ proto.notify('window/showMessage', {
+ type = define.MessageType.Warning,
+ message = '文件发生了变化,替换取消。'
+ })
+ return false
+ end
+ if not item then
+ Forcing = false
+ return false
+ end
+ if item.title == '强制替换' then
+ Forcing = true
+ return true
+ else
+ Forcing = false
+ return false
+ end
+end
+
+local function askForMultiChange(results, newname)
+ -- TODO 总是可以替换
+ do return true end
+ if TEST then
+ return true
+ end
+ local uris = {}
+ for _, result in ipairs(results) do
+ local uri = result.uri
+ if not uris[uri] then
+ uris[uri] = 0
+ uris[#uris+1] = uri
+ end
+ uris[uri] = uris[uri] + 1
+ end
+ if #uris <= 1 then
+ return true
+ end
+
+ local version = files.globalVersion
+ -- TODO
+ local item = proto.awaitRequest('window/showMessageRequest', {
+ type = define.MessageType.Warning,
+ message = ('将修改 %d 个文件,共 %d 处。'):format(
+ #uris,
+ #results
+ ),
+ actions = {
+ {
+ title = '继续',
+ },
+ {
+ title = '放弃',
+ },
+ }
+ })
+ if version ~= files.globalVersion then
+ proto.notify('window/showMessage', {
+ type = define.MessageType.Warning,
+ message = '文件发生了变化,替换取消。'
+ })
+ return false
+ end
+ if item and item.title == '继续' then
+ local fileList = {}
+ for _, uri in ipairs(uris) do
+ fileList[#fileList+1] = ('%s (%d)'):format(uri, uris[uri])
+ end
+
+ log.debug(('Renamed [%s]\r\n%s'):format(newname, table.concat(fileList, '\r\n')))
+ return true
+ end
+ return false
+end
+
+local function trim(str)
+ return str:match '^%s*(%S+)%s*$'
+end
+
+local function isValidName(str)
+ return str:match '^[%a_][%w_]*$'
+end
+
+local function isValidGlobal(str)
+ for s in str:gmatch '[^%.]*' do
+ if not isValidName(trim(s)) then
+ return false
+ end
+ end
+ return true
+end
+
+local function isValidFunctionName(str)
+ if isValidGlobal(str) then
+ return true
+ end
+ local pos = str:find(':', 1, true)
+ if not pos then
+ return false
+ end
+ return isValidGlobal(trim(str:sub(1, pos-1)))
+ and isValidName(trim(str:sub(pos+1)))
+end
+
+local function isFunctionGlobalName(source)
+ local parent = source.parent
+ if parent.type ~= 'setglobal' then
+ return false
+ end
+ local value = parent.value
+ if not value.type ~= 'function' then
+ return false
+ end
+ return value.start <= parent.start
+end
+
+local function renameLocal(source, newname, callback)
+ if isValidName(newname) then
+ callback(source, source.start, source.finish, newname)
+ return
+ end
+ if askForcing(newname) then
+ callback(source, source.start, source.finish, newname)
+ end
+end
+
+local function renameField(source, newname, callback)
+ if isValidName(newname) then
+ callback(source, source.start, source.finish, newname)
+ return true
+ end
+ local parent = source.parent
+ if parent.type == 'setfield'
+ or parent.type == 'getfield' then
+ local dot = parent.dot
+ local newstr = '[' .. util.viewString(newname) .. ']'
+ callback(source, dot.start, source.finish, newstr)
+ elseif parent.type == 'tablefield' then
+ local newstr = '[' .. util.viewString(newname) .. ']'
+ callback(source, source.start, source.finish, newstr)
+ elseif parent.type == 'getmethod' then
+ if not askForcing(newname) then
+ return false
+ end
+ callback(source, source.start, source.finish, newname)
+ elseif parent.type == 'setmethod' then
+ local uri = guide.getUri(source)
+ local text = files.getText(uri)
+ local func = parent.value
+ -- function mt:name () end --> mt['newname'] = function (self) end
+ local newstr = string.format('%s[%s] = function '
+ , text:sub(parent.start, parent.node.finish)
+ , util.viewString(newname)
+ )
+ callback(source, func.start, parent.finish, newstr)
+ local pl = text:find('(', parent.finish, true)
+ if pl then
+ if func.args then
+ callback(source, pl + 1, pl, 'self, ')
+ else
+ callback(source, pl + 1, pl, 'self')
+ end
+ end
+ end
+ return true
+end
+
+local function renameGlobal(source, newname, callback)
+ if isValidGlobal(newname) then
+ callback(source, source.start, source.finish, newname)
+ return true
+ end
+ if isValidFunctionName(newname) then
+ if not isFunctionGlobalName(source) then
+ askForcing(newname)
+ end
+ callback(source, source.start, source.finish, newname)
+ return true
+ end
+ local newstr = '_ENV[' .. util.viewString(newname) .. ']'
+ -- function name () end --> _ENV['newname'] = function () end
+ if source.value and source.value.type == 'function'
+ and source.value.start < source.start then
+ callback(source, source.value.start, source.finish, newstr .. ' = function ')
+ return true
+ end
+ callback(source, source.start, source.finish, newstr)
+ return true
+end
+
+local function ofLocal(source, newname, callback)
+ renameLocal(source, newname, callback)
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ renameLocal(ref, newname, callback)
+ end
+ end
+end
+
+local function ofFieldThen(key, src, newname, callback)
+ if vm.getKeyName(src) ~= key then
+ return
+ end
+ if src.type == 'tablefield'
+ or src.type == 'getfield'
+ or src.type == 'setfield' then
+ src = src.field
+ elseif src.type == 'tableindex'
+ or src.type == 'getindex'
+ or src.type == 'setindex' then
+ src = src.index
+ elseif src.type == 'getmethod'
+ or src.type == 'setmethod' then
+ src = src.method
+ end
+ if src.type == 'string' then
+ local quo = src[2]
+ local text = util.viewString(newname, quo)
+ callback(src, src.start, src.finish, text)
+ return
+ elseif src.type == 'field'
+ or src.type == 'method' then
+ local suc = renameField(src, newname, callback)
+ if not suc then
+ return
+ end
+ elseif src.type == 'setglobal'
+ or src.type == 'getglobal' then
+ local suc = renameGlobal(src, newname, callback)
+ if not suc then
+ return
+ end
+ end
+end
+
+local function ofField(source, newname, callback)
+ local key = guide.getKeyName(source)
+ local node
+ if source.type == 'tablefield'
+ or source.type == 'tableindex' then
+ node = source.parent
+ else
+ node = source.node
+ end
+ for _, src in ipairs(vm.getFields(node, 'deep')) do
+ ofFieldThen(key, src, newname, callback)
+ end
+end
+
+local function ofGlobal(source, newname, callback)
+ local key = guide.getKeyName(source)
+ for _, src in ipairs(vm.getRefs(source, 'deep')) do
+ ofFieldThen(key, src, newname, callback)
+ end
+end
+
+local function ofLabel(source, newname, callback)
+ if not isValidName(newname) and not askForcing(newname)then
+ return false
+ end
+ for _, src in ipairs(vm.getRefs(source, 'deep')) do
+ callback(src, src.start, src.finish, newname)
+ end
+end
+
+local function rename(source, newname, callback)
+ if source.type == 'label'
+ or source.type == 'goto' then
+ return ofLabel(source, newname, callback)
+ elseif source.type == 'local' then
+ return ofLocal(source, newname, callback)
+ elseif source.type == 'setlocal'
+ or source.type == 'getlocal' then
+ return ofLocal(source.node, newname, callback)
+ elseif source.type == 'field'
+ or source.type == 'method'
+ or source.type == 'index' then
+ return ofField(source.parent, newname, callback)
+ elseif source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ return ofGlobal(source, newname, callback)
+ elseif source.type == 'string'
+ or source.type == 'number'
+ or source.type == 'boolean' then
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ if parent.type == 'setindex'
+ or parent.type == 'getindex'
+ or parent.type == 'tableindex' then
+ return ofField(parent, newname, callback)
+ end
+ end
+ return
+end
+
+local function prepareRename(source)
+ if source.type == 'label'
+ or source.type == 'goto'
+ or source.type == 'local'
+ or source.type == 'setlocal'
+ or source.type == 'getlocal'
+ or source.type == 'field'
+ or source.type == 'method'
+ or source.type == 'tablefield'
+ or source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ return source, source[1]
+ elseif source.type == 'string'
+ or source.type == 'number'
+ or source.type == 'boolean' then
+ local parent = source.parent
+ if not parent then
+ return nil
+ end
+ if parent.type == 'setindex'
+ or parent.type == 'getindex'
+ or parent.type == 'tableindex' then
+ return source, source[1]
+ end
+ return nil
+ end
+ return nil
+end
+
+local accept = {
+ ['label'] = true,
+ ['goto'] = true,
+ ['local'] = true,
+ ['setlocal'] = true,
+ ['getlocal'] = true,
+ ['field'] = true,
+ ['method'] = true,
+ ['tablefield'] = true,
+ ['setglobal'] = true,
+ ['getglobal'] = true,
+ ['string'] = true,
+ ['boolean'] = true,
+ ['number'] = true,
+}
+
+local m = {}
+
+function m.rename(uri, pos, newname)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+ local source = findSource(ast, pos, accept)
+ if not source then
+ return nil
+ end
+ local results = {}
+ local mark = {}
+
+ rename(source, newname, function (target, start, finish, text)
+ local turi = files.getOriginUri(guide.getUri(target))
+ local uid = turi .. start
+ if mark[uid] then
+ return
+ end
+ mark[uid] = true
+ if files.isLibrary(turi) then
+ return
+ end
+ results[#results+1] = {
+ start = start,
+ finish = finish,
+ text = text,
+ uri = turi,
+ }
+ end)
+
+ if Forcing == false then
+ Forcing = nil
+ return nil
+ end
+
+ if #results == 0 then
+ return nil
+ end
+
+ if not askForMultiChange(results, newname) then
+ return nil
+ end
+
+ return results
+end
+
+function m.prepareRename(uri, pos)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+ local source = findSource(ast, pos, accept)
+ if not source then
+ return
+ end
+
+ local res, text = prepareRename(source)
+ if not res then
+ return nil
+ end
+
+ return {
+ start = source.start,
+ finish = source.finish,
+ text = text,
+ }
+end
+
+return m
diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua
new file mode 100644
index 00000000..e6b35cdd
--- /dev/null
+++ b/script/core/semantic-tokens.lua
@@ -0,0 +1,161 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local await = require 'await'
+local define = require 'proto.define'
+local vm = require 'vm'
+local util = require 'utility'
+
+local Care = {}
+Care['setglobal'] = function (source, results)
+ local isLib = vm.isGlobalLibraryName(source[1])
+ if not isLib then
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.namespace,
+ modifieres = define.TokenModifiers.deprecated,
+ }
+ end
+end
+Care['getglobal'] = function (source, results)
+ local isLib = vm.isGlobalLibraryName(source[1])
+ if not isLib then
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.namespace,
+ modifieres = define.TokenModifiers.deprecated,
+ }
+ end
+end
+Care['tablefield'] = function (source, results)
+ local field = source.field
+ if not field then
+ return
+ end
+ results[#results+1] = {
+ start = field.start,
+ finish = field.finish,
+ type = define.TokenTypes.property,
+ modifieres = define.TokenModifiers.declaration,
+ }
+end
+Care['getlocal'] = function (source, results)
+ local loc = source.node
+ -- 1. 值为函数的局部变量
+ local hasFunc
+ local node = loc.node
+ if node then
+ for _, ref in ipairs(node.ref) do
+ local def = ref.value
+ if def.type == 'function' then
+ hasFunc = true
+ break
+ end
+ end
+ end
+ if hasFunc then
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.interface,
+ modifieres = define.TokenModifiers.declaration,
+ }
+ return
+ end
+ -- 2. 对象
+ if source.parent.type == 'getmethod'
+ and source.parent.node == source then
+ return
+ end
+ -- 3. 函数的参数
+ if loc.parent and loc.parent.type == 'funcargs' then
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.parameter,
+ modifieres = define.TokenModifiers.declaration,
+ }
+ return
+ end
+ -- 4. 特殊变量
+ if source[1] == '_ENV'
+ or source[1] == 'self' then
+ return
+ end
+ -- 5. 其他
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.variable,
+ }
+end
+Care['setlocal'] = Care['getlocal']
+Care['doc.return.name'] = function (source, results)
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.parameter,
+ }
+end
+
+local function buildTokens(results, text, lines)
+ local tokens = {}
+ local lastLine = 0
+ local lastStartChar = 0
+ for i, source in ipairs(results) do
+ local row, col = guide.positionOf(lines, source.start)
+ local start = guide.lineRange(lines, row)
+ local ucol = util.utf8Len(text, start, start + col - 1)
+ local line = row - 1
+ local startChar = ucol - 1
+ local deltaLine = line - lastLine
+ local deltaStartChar
+ if deltaLine == 0 then
+ deltaStartChar = startChar - lastStartChar
+ else
+ deltaStartChar = startChar
+ end
+ lastLine = line
+ lastStartChar = startChar
+ -- see https://microsoft.github.io/language-server-protocol/specifications/specification-3-16/#textDocument_semanticTokens
+ local len = i * 5 - 5
+ tokens[len + 1] = deltaLine
+ tokens[len + 2] = deltaStartChar
+ tokens[len + 3] = source.finish - source.start + 1 -- length
+ tokens[len + 4] = source.type
+ tokens[len + 5] = source.modifieres or 0
+ end
+ return tokens
+end
+
+return function (uri, start, finish)
+ local ast = files.getAst(uri)
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+ if not ast then
+ return nil
+ end
+
+ local results = {}
+ local count = 0
+ guide.eachSourceBetween(ast.ast, start, finish, function (source)
+ local method = Care[source.type]
+ if not method then
+ return
+ end
+ method(source, results)
+ count = count + 1
+ if count % 100 == 0 then
+ await.delay()
+ end
+ end)
+
+ table.sort(results, function (a, b)
+ return a.start < b.start
+ end)
+
+ local tokens = buildTokens(results, text, lines)
+
+ return tokens
+end
diff --git a/script/core/signature.lua b/script/core/signature.lua
new file mode 100644
index 00000000..dad38924
--- /dev/null
+++ b/script/core/signature.lua
@@ -0,0 +1,106 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local hoverLabel = require 'core.hover.label'
+local hoverDesc = require 'core.hover.description'
+
+local function findNearCall(uri, ast, pos)
+ local text = files.getText(uri)
+ -- 检查 `f()$` 的情况,注意要区别于 `f($`
+ if text:sub(pos, pos) == ')' then
+ return nil
+ end
+
+ local nearCall
+ guide.eachSourceContain(ast.ast, pos, function (src)
+ if src.type == 'call'
+ or src.type == 'table'
+ or src.type == 'function' then
+ if not nearCall or nearCall.start < src.start then
+ nearCall = src
+ end
+ end
+ end)
+ if not nearCall then
+ return nil
+ end
+ if nearCall.type ~= 'call' then
+ return nil
+ end
+ return nearCall
+end
+
+local function makeOneSignature(source, oop, index)
+ local label = hoverLabel(source, oop)
+ -- 去掉返回值
+ label = label:gsub('%s*->.+', '')
+ local params = {}
+ local i = 0
+ for start, finish in label:gmatch '[%(%)%,]%s*().-()%s*%f[%(%)%,%[%]]' do
+ i = i + 1
+ params[i] = {
+ label = {start, finish-1},
+ }
+ end
+ -- 不定参数
+ if index > i and i > 0 then
+ local lastLabel = params[i].label
+ local text = label:sub(lastLabel[1], lastLabel[2])
+ if text == '...' then
+ index = i
+ end
+ end
+ return {
+ label = label,
+ params = params,
+ index = index,
+ description = hoverDesc(source),
+ }
+end
+
+local function makeSignatures(call, pos)
+ local node = call.node
+ local oop = node.type == 'method'
+ or node.type == 'getmethod'
+ or node.type == 'setmethod'
+ local index
+ local args = call.args
+ if args then
+ for i, arg in ipairs(args) do
+ if arg.start <= pos and arg.finish >= pos then
+ index = i
+ break
+ end
+ end
+ if not index then
+ index = #args + 1
+ end
+ else
+ index = 1
+ end
+ local signs = {}
+ local defs = vm.getDefs(node, 'deep')
+ for _, src in ipairs(defs) do
+ if src.type == 'function'
+ or src.type == 'doc.type.function' then
+ signs[#signs+1] = makeOneSignature(src, oop, index)
+ end
+ end
+ return signs
+end
+
+return function (uri, pos)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+ local call = findNearCall(uri, ast, pos)
+ if not call then
+ return nil
+ end
+ local signs = makeSignatures(call, pos)
+ if not signs or #signs == 0 then
+ return nil
+ end
+ return signs
+end
diff --git a/script/core/workspace-symbol.lua b/script/core/workspace-symbol.lua
new file mode 100644
index 00000000..4fc6a854
--- /dev/null
+++ b/script/core/workspace-symbol.lua
@@ -0,0 +1,69 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local matchKey = require 'core.matchkey'
+local define = require 'proto.define'
+local await = require 'await'
+
+local function buildSource(uri, source, key, results)
+ if source.type == 'local'
+ or source.type == 'setlocal'
+ or source.type == 'setglobal' then
+ local name = source[1]
+ if matchKey(key, name) then
+ results[#results+1] = {
+ name = name,
+ kind = define.SymbolKind.Variable,
+ uri = uri,
+ range = { source.start, source.finish },
+ }
+ end
+ elseif source.type == 'setfield'
+ or source.type == 'tablefield' then
+ local field = source.field
+ local name = field[1]
+ if matchKey(key, name) then
+ results[#results+1] = {
+ name = name,
+ kind = define.SymbolKind.Field,
+ uri = uri,
+ range = { field.start, field.finish },
+ }
+ end
+ elseif source.type == 'setmethod' then
+ local method = source.method
+ local name = method[1]
+ if matchKey(key, name) then
+ results[#results+1] = {
+ name = name,
+ kind = define.SymbolKind.Method,
+ uri = uri,
+ range = { method.start, method.finish },
+ }
+ end
+ end
+end
+
+local function searchFile(uri, key, results)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ guide.eachSource(ast.ast, function (source)
+ buildSource(uri, source, key, results)
+ end)
+end
+
+return function (key)
+ local results = {}
+
+ for uri in files.eachFile() do
+ searchFile(files.getOriginUri(uri), key, results)
+ if #results > 1000 then
+ break
+ end
+ await.delay()
+ end
+
+ return results
+end
diff --git a/script/doctor.lua b/script/doctor.lua
new file mode 100644
index 00000000..08ec69cf
--- /dev/null
+++ b/script/doctor.lua
@@ -0,0 +1,380 @@
+local type = type
+local next = next
+local ipairs = ipairs
+local rawget = rawget
+local pcall = pcall
+local getregistry = debug.getregistry
+local getmetatable = debug.getmetatable
+local getupvalue = debug.getupvalue
+local getuservalue = debug.getuservalue
+local getlocal = debug.getlocal
+local getinfo = debug.getinfo
+local maxinterger = math.maxinteger
+local mathType = math.type
+local tableConcat = table.concat
+local _G = _G
+local registry = getregistry()
+local tableSort = table.sort
+
+_ENV = nil
+
+local m = {}
+
+local function getTostring(obj)
+ local mt = getmetatable(obj)
+ if not mt then
+ return nil
+ end
+ local toString = rawget(mt, '__tostring')
+ if not toString then
+ return nil
+ end
+ local suc, str = pcall(toString, obj)
+ if not suc then
+ return nil
+ end
+ if type(str) ~= 'string' then
+ return nil
+ end
+ return str
+end
+
+local function formatName(obj)
+ local tp = type(obj)
+ if tp == 'nil' then
+ return 'nil:nil'
+ elseif tp == 'boolean' then
+ if obj == true then
+ return 'boolean:true'
+ else
+ return 'boolean:false'
+ end
+ elseif tp == 'number' then
+ if mathType(obj) == 'integer' then
+ return ('number:%d'):format(obj)
+ else
+ -- 如果浮点数可以完全表示为整数,那么就转换为整数
+ local str = ('%.10f'):format(obj):gsub('%.?[0]+$', '')
+ if str:find('.', 1, true) then
+ -- 如果浮点数不能表示为整数,那么再加上它的精确表示法
+ str = ('%s(%q)'):format(str, obj)
+ end
+ return 'number:' .. str
+ end
+ elseif tp == 'string' then
+ local str = ('%q'):format(obj)
+ if #str > 100 then
+ local new = ('%s...(len=%d)'):format(str:sub(1, 100), #str)
+ if #new < #str then
+ str = new
+ end
+ end
+ return 'string:' .. str
+ elseif tp == 'function' then
+ local info = getinfo(obj, 'S')
+ if info.what == 'c' then
+ return ('function:%p(C)'):format(obj)
+ elseif info.what == 'main' then
+ return ('function:%p(main)'):format(obj)
+ else
+ return ('function:%p(%s:%d-%d)'):format(obj, info.source, info.linedefined, info.lastlinedefined)
+ end
+ elseif tp == 'table' then
+ local id = getTostring(obj)
+ if not id then
+ if obj == _G then
+ id = '_G'
+ elseif obj == registry then
+ id = 'registry'
+ end
+ end
+ if id then
+ return ('table:%p(%s)'):format(obj, id)
+ else
+ return ('table:%p'):format(obj)
+ end
+ elseif tp == 'userdata' then
+ local id = getTostring(obj)
+ if id then
+ return ('userdata:%p(%s)'):format(obj, id)
+ else
+ return ('userdata:%p'):format(obj)
+ end
+ else
+ return ('%s:%p'):format(tp, obj)
+ end
+end
+
+--- 内存快照
+---@return table
+function m.snapshot()
+ local mark = {}
+ local find
+
+ local function findTable(t, result)
+ result = result or {}
+ local mt = getmetatable(t)
+ local wk, wv
+ if mt then
+ local mode = rawget(mt, '__mode')
+ if type(mode) == 'string' then
+ if mode:find('k', 1, true) then
+ wk = true
+ end
+ if mode:find('v', 1, true) then
+ wv = true
+ end
+ end
+ end
+ for k, v in next, t do
+ if not wk then
+ local keyInfo = find(k)
+ if keyInfo then
+ result[#result+1] = {
+ type = 'key',
+ name = formatName(k),
+ info = keyInfo,
+ }
+ end
+ end
+ if not wv then
+ local valueInfo = find(v)
+ if valueInfo then
+ result[#result+1] = {
+ type = 'field',
+ name = formatName(k) .. '|' .. formatName(v),
+ info = valueInfo,
+ }
+ end
+ end
+ end
+ local MTInfo = find(getmetatable(t))
+ if MTInfo then
+ result[#result+1] = {
+ type = 'metatable',
+ name = '',
+ info = MTInfo,
+ }
+ end
+ if #result == 0 then
+ return nil
+ end
+ return result
+ end
+
+ local function findFunction(f, result, trd, stack)
+ result = result or {}
+ for i = 1, maxinterger do
+ local n, v = getupvalue(f, i)
+ if not n then
+ break
+ end
+ local valueInfo = find(v)
+ if valueInfo then
+ result[#result+1] = {
+ type = 'upvalue',
+ name = n,
+ info = valueInfo,
+ }
+ end
+ end
+ if trd then
+ for i = 1, maxinterger do
+ local n, l = getlocal(trd, stack, i)
+ if not n then
+ break
+ end
+ local valueInfo = find(l)
+ if valueInfo then
+ result[#result+1] = {
+ type = 'local',
+ name = n,
+ info = valueInfo,
+ }
+ end
+ end
+ end
+ if #result == 0 then
+ return nil
+ end
+ return result
+ end
+
+ local function findUserData(u, result)
+ result = result or {}
+ for i = 1, maxinterger do
+ local v, b = getuservalue(u, i)
+ if not b then
+ break
+ end
+ local valueInfo = find(v)
+ if valueInfo then
+ result[#result+1] = {
+ type = 'uservalue',
+ name = formatName(i),
+ info = valueInfo,
+ }
+ end
+ end
+ local MTInfo = find(getmetatable(u))
+ if MTInfo then
+ result[#result+1] = {
+ type = 'metatable',
+ name = '',
+ info = MTInfo,
+ }
+ end
+ if #result == 0 then
+ return nil
+ end
+ return result
+ end
+
+ local function findThread(trd, result)
+ -- 不查找主线程,主线程一定是临时的(视为弱引用)
+ if trd == registry[1] then
+ return nil
+ end
+ result = result or {}
+
+ for i = 1, maxinterger do
+ local info = getinfo(trd, i, 'Sf')
+ if not info then
+ break
+ end
+ local funcInfo = find(info.func, trd, i)
+ if funcInfo then
+ result[#result+1] = {
+ type = 'stack',
+ name = i .. '@' .. formatName(info.func),
+ info = funcInfo,
+ }
+ end
+ end
+
+ if #result == 0 then
+ return nil
+ end
+ return result
+ end
+
+ function find(obj, trd, stack)
+ if mark[obj] then
+ return mark[obj]
+ end
+ local tp = type(obj)
+ if tp == 'table' then
+ mark[obj] = {}
+ mark[obj] = findTable(obj, mark[obj])
+ elseif tp == 'function' then
+ mark[obj] = {}
+ mark[obj] = findFunction(obj, mark[obj], trd, stack)
+ elseif tp == 'userdata' then
+ mark[obj] = {}
+ mark[obj] = findUserData(obj, mark[obj])
+ elseif tp == 'thread' then
+ mark[obj] = {}
+ mark[obj] = findThread(obj, mark[obj])
+ else
+ return nil
+ end
+ if mark[obj] then
+ mark[obj].object = obj
+ end
+ return mark[obj]
+ end
+
+ return {
+ name = formatName(registry),
+ type = 'root',
+ info = find(registry),
+ }
+end
+
+--- 寻找对象的引用
+---@return string
+function m.catch(...)
+ local targets = {}
+ for _, target in ipairs {...} do
+ targets[target] = true
+ end
+ local report = m.snapshot()
+ local path = {}
+ local result = {}
+ local mark = {}
+
+ local function push()
+ result[#result+1] = tableConcat(path, ' => ')
+ end
+
+ local function search(t)
+ path[#path+1] = ('(%s)%s'):format(t.type, t.name)
+ local addTarget
+ if targets[t.info.object] then
+ targets[t.info.object] = nil
+ addTarget = t.info.object
+ push(t)
+ end
+ if not mark[t.info] then
+ mark[t.info] = true
+ for _, obj in ipairs(t.info) do
+ search(obj)
+ end
+ end
+ path[#path] = nil
+ if addTarget then
+ targets[addTarget] = true
+ end
+ end
+
+ search(report)
+
+ return result
+end
+
+--- 生成一个报告
+---@return string
+function m.report()
+ local snapshot = m.snapshot()
+ local cache = {}
+ local mark = {}
+
+ local function scan(t)
+ local obj = t.info.object
+ local tp = type(obj)
+ if tp == 'table'
+ or tp == 'userdata'
+ or tp == 'function'
+ or tp == 'string'
+ or tp == 'thread' then
+ local point = ('%p'):format(obj)
+ if not cache[point] then
+ cache[point] = {
+ point = point,
+ count = 0,
+ name = formatName(obj),
+ }
+ end
+ cache[point].count = cache[point].count + 1
+ end
+ if not mark[t.info] then
+ mark[t.info] = true
+ for _, child in ipairs(t.info) do
+ scan(child)
+ end
+ end
+ end
+
+ scan(snapshot)
+
+ local list = {}
+ for _, info in next, cache do
+ list[#list+1] = info
+ end
+ tableSort(list, function (a, b)
+ return a.name < b.name
+ end)
+ return list
+end
+
+return m
diff --git a/script/file-uri.lua b/script/file-uri.lua
new file mode 100644
index 00000000..ba44f2e7
--- /dev/null
+++ b/script/file-uri.lua
@@ -0,0 +1,89 @@
+local platform = require 'bee.platform'
+
+local escPatt = '[^%w%-%.%_%~%/]'
+
+local function esc(c)
+ return ('%%%02X'):format(c:byte())
+end
+
+local function normalize(str)
+ return str:gsub('%%(%x%x)', function (n)
+ return string.char(tonumber(n, 16))
+ end)
+end
+
+local m = {}
+
+-- c:\my\files --> file:///c%3A/my/files
+-- /usr/home --> file:///usr/home
+-- \\server\share\some\path --> file://server/share/some/path
+
+--- path -> uri
+---@param path string
+---@return string uri
+function m.encode(path)
+ local authority = ''
+ if platform.OS == 'Windows' then
+ path = path:gsub('\\', '/')
+ end
+
+ if path:sub(1, 2) == '//' then
+ local idx = path:find('/', 3)
+ if idx then
+ authority = path:sub(3, idx)
+ path = path:sub(idx + 1)
+ if path == '' then
+ path = '/'
+ end
+ else
+ authority = path:sub(3)
+ path = '/'
+ end
+ end
+
+ if path:sub(1, 1) ~= '/' then
+ path = '/' .. path
+ end
+
+ -- lower-case windows drive letters in /C:/fff or C:/fff
+ local start, finish, drive = path:find '/(%u):'
+ if drive then
+ path = path:sub(1, start) .. drive:lower() .. path:sub(finish, -1)
+ end
+
+ local uri = 'file://'
+ .. authority:gsub(escPatt, esc)
+ .. path:gsub(escPatt, esc)
+ return uri
+end
+
+-- file:///c%3A/my/files --> c:\my\files
+-- file:///usr/home --> /usr/home
+-- file://server/share/some/path --> \\server\share\some\path
+
+--- uri -> path
+---@param uri string
+---@return string path
+function m.decode(uri)
+ local scheme, authority, path = uri:match('([^:]*):?/?/?([^/]*)(.*)')
+ if not scheme then
+ return ''
+ end
+ scheme = normalize(scheme)
+ authority = normalize(authority)
+ path = normalize(path)
+ local value
+ if scheme == 'file' and #authority > 0 and #path > 1 then
+ value = '//' .. authority .. path
+ elseif path:match '/%a:' then
+ value = path:sub(2, 2):lower() .. path:sub(3)
+ else
+ value = path
+ end
+ if platform.OS == 'Windows' then
+ value = value:gsub('/', '\\')
+ end
+ return value
+end
+
+return m
diff --git a/script/files.lua b/script/files.lua
new file mode 100644
index 00000000..4d34568d
--- /dev/null
+++ b/script/files.lua
@@ -0,0 +1,438 @@
+local platform = require 'bee.platform'
+local config = require 'config'
+local glob = require 'glob'
+local furi = require 'file-uri'
+local parser = require 'parser'
+local proto = require 'proto'
+local lang = require 'language'
+local await = require 'await'
+local timer = require 'timer'
+
+local m = {}
+
+m.openMap = {}
+m.libraryMap = {}
+m.fileMap = {}
+m.watchList = {}
+m.notifyCache = {}
+m.assocVersion = -1
+m.assocMatcher = nil
+m.globalVersion = 0
+m.linesMap = setmetatable({}, { __mode = 'v' })
+m.astMap = setmetatable({}, { __mode = 'v' })
+
+--- 打开文件
+---@param uri string
+function m.open(uri)
+ local originUri = uri
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ m.openMap[uri] = true
+ m.onWatch('open', originUri)
+end
+
+--- 关闭文件
+---@param uri string
+function m.close(uri)
+ local originUri = uri
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ m.openMap[uri] = nil
+ m.onWatch('close', originUri)
+end
+
+--- 是否打开
+---@param uri string
+---@return boolean
+function m.isOpen(uri)
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ return m.openMap[uri] == true
+end
+
+--- 标记为库文件
+function m.setLibraryPath(uri, libraryPath)
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ m.libraryMap[uri] = libraryPath
+end
+
+--- 是否是库文件
+function m.isLibrary(uri)
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ return m.libraryMap[uri] ~= nil
+end
+
+--- 获取库文件的根目录
+function m.getLibraryPath(uri)
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ return m.libraryMap[uri]
+end
+
+function m.flushAllLibrary()
+ m.libraryMap = {}
+end
+
+--- 是否存在
+---@return boolean
+function m.exists(uri)
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ return m.fileMap[uri] ~= nil
+end
+
+function m.asKey(uri)
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ return uri
+end
+
+--- 设置文件文本
+---@param uri string
+---@param text string
+function m.setText(uri, text)
+ if not text then
+ return
+ end
+ local originUri = uri
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ local create
+ if not m.fileMap[uri] then
+ m.fileMap[uri] = {
+ uri = originUri,
+ version = 0,
+ }
+ create = true
+ end
+ local file = m.fileMap[uri]
+ if file.text == text then
+ return
+ end
+ file.text = text
+ m.linesMap[uri] = nil
+ m.astMap[uri] = nil
+ file.cache = {}
+ file.cacheActiveTime = math.huge
+ file.version = file.version + 1
+ m.globalVersion = m.globalVersion + 1
+ await.close('files.version')
+ if create then
+ m.onWatch('create', originUri)
+ end
+ m.onWatch('update', originUri)
+end
+
+--- 获取文件版本
+function m.getVersion(uri)
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ local file = m.fileMap[uri]
+ if not file then
+ return nil
+ end
+ return file.version
+end
+
+--- 获取文件文本
+---@param uri string
+---@return string text
+function m.getText(uri)
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ local file = m.fileMap[uri]
+ if not file then
+ return nil
+ end
+ return file.text
+end
+
+--- 移除文件
+---@param uri string
+function m.remove(uri)
+ local originUri = uri
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ local file = m.fileMap[uri]
+ if not file then
+ return
+ end
+ m.fileMap[uri] = nil
+
+ m.globalVersion = m.globalVersion + 1
+ await.close('files.version')
+ m.onWatch('remove', originUri)
+end
+
+--- 移除所有文件
+function m.removeAll()
+ m.globalVersion = m.globalVersion + 1
+ await.close('files.version')
+ for uri in pairs(m.fileMap) do
+ if not m.libraryMap[uri] then
+ m.fileMap[uri] = nil
+ m.astMap[uri] = nil
+ m.linesMap[uri] = nil
+ m.onWatch('remove', uri)
+ end
+ end
+ --m.notifyCache = {}
+end
+
+--- 移除所有关闭的文件
+function m.removeAllClosed()
+ m.globalVersion = m.globalVersion + 1
+ await.close('files.version')
+ for uri in pairs(m.fileMap) do
+ if not m.openMap[uri]
+ and not m.libraryMap[uri] then
+ m.fileMap[uri] = nil
+ m.astMap[uri] = nil
+ m.linesMap[uri] = nil
+ m.onWatch('remove', uri)
+ end
+ end
+ --m.notifyCache = {}
+end
+
+--- 遍历文件
+function m.eachFile()
+ return pairs(m.fileMap)
+end
+
+function m.compileAst(uri, text)
+ if not m.isOpen(uri) and #text >= config.config.workspace.preloadFileSize * 1000 then
+ if not m.notifyCache['preloadFileSize'] then
+ m.notifyCache['preloadFileSize'] = {}
+ m.notifyCache['skipLargeFileCount'] = 0
+ end
+ if not m.notifyCache['preloadFileSize'][uri] then
+ m.notifyCache['preloadFileSize'][uri] = true
+ m.notifyCache['skipLargeFileCount'] = m.notifyCache['skipLargeFileCount'] + 1
+ if m.notifyCache['skipLargeFileCount'] <= 3 then
+ local ws = require 'workspace'
+ proto.notify('window/showMessage', {
+ type = 3,
+ message = lang.script('WORKSPACE_SKIP_LARGE_FILE'
+ , ws.getRelativePath(uri)
+ , config.config.workspace.preloadFileSize
+ , #text / 1000
+ ),
+ })
+ end
+ end
+ return nil
+ end
+ local clock = os.clock()
+ local state, err = parser:compile(text
+ , 'lua'
+ , config.config.runtime.version
+ , {
+ special = config.config.runtime.special,
+ }
+ )
+ local passed = os.clock() - clock
+ if passed > 0.1 then
+ log.warn(('Compile [%s] takes [%.3f] sec, size [%.3f] kb.'):format(uri, passed, #text / 1000))
+ end
+ if state then
+ state.uri = uri
+ state.ast.uri = uri
+ if config.config.luadoc.enable then
+ parser:luadoc(state)
+ end
+ return state
+ else
+ log.error(err)
+ return nil
+ end
+end
+
+--- 获取文件语法树
+---@param uri string
+---@return table ast
+function m.getAst(uri)
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ if uri ~= '' and not m.isLua(uri) then
+ return nil
+ end
+ local file = m.fileMap[uri]
+ if not file then
+ return nil
+ end
+ local ast = m.astMap[uri]
+ if not ast then
+ ast = m.compileAst(uri, file.text)
+ m.astMap[uri] = ast
+ end
+ file.cacheActiveTime = timer.clock()
+ return ast
+end
+
+--- 获取文件行信息
+---@param uri string
+---@return table lines
+function m.getLines(uri)
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ local file = m.fileMap[uri]
+ if not file then
+ return nil
+ end
+ local lines = m.linesMap[uri]
+ if not lines then
+ lines = parser:lines(file.text)
+ m.linesMap[uri] = lines
+ end
+ return lines
+end
+
+--- 获取原始uri
+function m.getOriginUri(uri)
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ local file = m.fileMap[uri]
+ if not file then
+ return nil
+ end
+ return file.uri
+end
+
+function m.getUri(uri)
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ return uri
+end
+
+--- 获取文件的自定义缓存信息(在文件内容更新后自动失效)
+function m.getCache(uri)
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ local file = m.fileMap[uri]
+ if not file then
+ return nil
+ end
+ file.cacheActiveTime = timer.clock()
+ return file.cache
+end
+
+--- 判断文件名相等
+function m.eq(a, b)
+ if platform.OS == 'Windows' then
+ return a:lower() == b:lower()
+ else
+ return a == b
+ end
+end
+
+--- 获取文件关联
+function m.getAssoc()
+ if m.assocVersion == config.version then
+ return m.assocMatcher
+ end
+ m.assocVersion = config.version
+ local patt = {}
+ for k, v in pairs(config.other.associations) do
+ if m.eq(v, 'lua') then
+ patt[#patt+1] = k
+ end
+ end
+ m.assocMatcher = glob.glob(patt)
+ if platform.OS == 'Windows' then
+ m.assocMatcher:setOption 'ignoreCase'
+ end
+ return m.assocMatcher
+end
+
+--- 判断是否是Lua文件
+---@param uri string
+---@return boolean
+function m.isLua(uri)
+ local ext = uri:match '%.([^%.%/%\\]-)$'
+ if not ext then
+ return false
+ end
+ if m.eq(ext, 'lua') then
+ return true
+ end
+ local matcher = m.getAssoc()
+ local path = furi.decode(uri)
+ return matcher(path)
+end
+
+--- 注册事件
+function m.watch(callback)
+ m.watchList[#m.watchList+1] = callback
+end
+
+function m.onWatch(ev, ...)
+ for _, callback in ipairs(m.watchList) do
+ callback(ev, ...)
+ end
+end
+
+function m.flushCache()
+ for uri, file in pairs(m.fileMap) do
+ file.cacheActiveTime = math.huge
+ m.linesMap[uri] = nil
+ m.astMap[uri] = nil
+ file.cache = {}
+ end
+end
+
+function m.flushFileCache(uri)
+ if platform.OS == 'Windows' then
+ uri = uri:lower()
+ end
+ local file = m.fileMap[uri]
+ if not file then
+ return
+ end
+ file.cacheActiveTime = math.huge
+ m.linesMap[uri] = nil
+ m.astMap[uri] = nil
+ file.cache = {}
+end
+
+local function init()
+ --TODO 可以清空文件缓存,之后看要不要启用吧
+ --timer.loop(10, function ()
+ -- local list = {}
+ -- for _, file in pairs(m.fileMap) do
+ -- if timer.clock() - file.cacheActiveTime > 10.0 then
+ -- file.cacheActiveTime = math.huge
+ -- file.ast = nil
+ -- file.cache = {}
+ -- list[#list+1] = file.uri
+ -- end
+ -- end
+ -- if #list > 0 then
+ -- log.info('Flush file caches:', #list, '\n', table.concat(list, '\n'))
+ -- collectgarbage()
+ -- end
+ --end)
+end
+
+xpcall(init, log.error)
+
+return m
diff --git a/script/fs-utility.lua b/script/fs-utility.lua
new file mode 100644
index 00000000..42041734
--- /dev/null
+++ b/script/fs-utility.lua
@@ -0,0 +1,559 @@
+local fs = require 'bee.filesystem'
+local platform = require 'bee.platform'
+
+local type = type
+local ioOpen = io.open
+local pcall = pcall
+local pairs = pairs
+local setmetatable = setmetatable
+local next = next
+local ipairs = ipairs
+local tostring = tostring
+local tableSort = table.sort
+
+_ENV = nil
+
+local m = {}
+--- 读取文件
+---@param path string
+function m.loadFile(path)
+ if type(path) ~= 'string' then
+ path = path:string()
+ end
+ local f, e = ioOpen(path, 'rb')
+ if not f then
+ return nil, e
+ end
+ if f:read(3) ~= '\xEF\xBB\xBF' then
+ f:seek("set")
+ end
+ local buf = f:read 'a'
+ f:close()
+ return buf
+end
+
+--- 写入文件
+---@param path string
+---@param content string
+function m.saveFile(path, content)
+ if type(path) ~= 'string' then
+ path = path:string()
+ end
+ local f, e = ioOpen(path, "wb")
+
+ if f then
+ f:write(content)
+ f:close()
+ return true
+ else
+ return false, e
+ end
+end
+
+local function buildOptional(optional)
+ optional = optional or {}
+ optional.add = optional.add or {}
+ optional.del = optional.del or {}
+ optional.mod = optional.mod or {}
+ optional.err = optional.err or {}
+ return optional
+end
+
+local function split(str, sep)
+ local t = {}
+ local current = 1
+ while current <= #str do
+ local s, e = str:find(sep, current)
+ if not s then
+ t[#t+1] = str:sub(current)
+ break
+ end
+ if s > 1 then
+ t[#t+1] = str:sub(current, s - 1)
+ end
+ current = e + 1
+ end
+ return t
+end
+
+local dfs = {}
+dfs.__index = dfs
+dfs.type = 'dummy'
+dfs.path = ''
+
+function m.dummyFS(t)
+ return setmetatable({
+ files = t or {},
+ }, dfs)
+end
+
+function dfs:__tostring()
+ return 'dummy:' .. tostring(self.path)
+end
+
+function dfs:__div(filename)
+ if type(filename) ~= 'string' then
+ filename = filename:string()
+ end
+ local new = m.dummyFS(self.files)
+ if self.path:sub(-1):match '[^/\\]' then
+ new.path = self.path .. '\\' .. filename
+ else
+ new.path = self.path .. filename
+ end
+ return new
+end
+
+function dfs:_open(index)
+ local paths = split(self.path, '[/\\]')
+ local current = self.files
+ if not index then
+ index = #paths
+ elseif index < 0 then
+ index = #paths + index + 1
+ end
+ for i = 1, index do
+ local path = paths[i]
+ if current[path] then
+ current = current[path]
+ else
+ return nil
+ end
+ end
+ return current
+end
+
+function dfs:_filename()
+ return self.path:match '[^/\\]+$'
+end
+
+function dfs:parent_path()
+ local new = m.dummyFS(self.files)
+ if self.path:find('[/\\]') then
+ new.path = self.path:gsub('[/\\]+[^/\\]*$', '')
+ else
+ new.path = ''
+ end
+ return new
+end
+
+function dfs:filename()
+ local new = m.dummyFS(self.files)
+ new.path = self:_filename()
+ return new
+end
+
+function dfs:string()
+ return self.path
+end
+
+function dfs:list_directory()
+ local dir = self:_open()
+ if type(dir) ~= 'table' then
+ return function () end
+ end
+ local keys = {}
+ for k in pairs(dir) do
+ keys[#keys+1] = k
+ end
+ tableSort(keys)
+ local i = 0
+ return function ()
+ i = i + 1
+ local k = keys[i]
+ if not k then
+ return nil
+ end
+ return self / k
+ end
+end
+
+function dfs:isDirectory()
+ local target = self:_open()
+ if type(target) == 'table' then
+ return true
+ end
+ return false
+end
+
+function dfs:remove()
+ local dir = self:_open(-2)
+ local filename = self:_filename()
+ if not filename then
+ return
+ end
+ dir[filename] = nil
+end
+
+function dfs:exists()
+ local target = self:_open()
+ return target ~= nil
+end
+
+function dfs:createDirectories(path)
+ if type(path) ~= 'string' then
+ path = path:string()
+ end
+ local paths = split(path, '[/\\]')
+ local current = self.files
+ for i = 1, #paths do
+ local sub = paths[i]
+ if current[sub] then
+ if type(current[sub]) ~= 'table' then
+ return false
+ end
+ else
+ current[sub] = {}
+ end
+ current = current[sub]
+ end
+ return true
+end
+
+function dfs:saveFile(path, text)
+ if type(path) ~= 'string' then
+ path = path:string()
+ end
+ local temp = m.dummyFS(self.files)
+ temp.path = path
+ local dir = temp:_open(-2)
+ if not dir then
+ return false, '无法打开:' .. path
+ end
+ local filename = temp:_filename()
+ if not filename then
+ return false, '无法打开:' .. path
+ end
+ if type(dir[filename]) == 'table' then
+ return false, '无法打开:' .. path
+ end
+ dir[filename] = text
+end
+
+local function fsAbsolute(path, optional)
+ if type(path) == 'string' then
+ local suc, res = pcall(fs.path, path)
+ if not suc then
+ optional.err[#optional.err+1] = res
+ return nil
+ end
+ path = res
+ elseif type(path) == 'table' then
+ return path
+ end
+ local suc, res = pcall(fs.absolute, path)
+ if not suc then
+ optional.err[#optional.err+1] = res
+ return nil
+ end
+ return res
+end
+
+local function fsIsDirectory(path, optional)
+ if path.type == 'dummy' then
+ return path:isDirectory()
+ end
+ local suc, res = pcall(fs.is_directory, path)
+ if not suc then
+ optional.err[#optional.err+1] = res
+ return false
+ end
+ return res
+end
+
+local function fsRemove(path, optional)
+ if path.type == 'dummy' then
+ return path:remove()
+ end
+ local suc, res = pcall(fs.remove, path)
+ if not suc then
+ optional.err[#optional.err+1] = res
+ end
+ optional.del[#optional.del+1] = path:string()
+end
+
+local function fsExists(path, optional)
+ if path.type == 'dummy' then
+ return path:exists()
+ end
+ local suc, res = pcall(fs.exists, path)
+ if not suc then
+ optional.err[#optional.err+1] = res
+ return false
+ end
+ return res
+end
+
+local function fsSave(path, text, optional)
+ if path.type == 'dummy' then
+ local dir = path:_open(-2)
+ if not dir then
+ optional.err[#optional.err+1] = '无法打开:' .. path:string()
+ return false
+ end
+ local filename = path:_filename()
+ if not filename then
+ optional.err[#optional.err+1] = '无法打开:' .. path:string()
+ return false
+ end
+ if type(dir[filename]) == 'table' then
+ optional.err[#optional.err+1] = '无法打开:' .. path:string()
+ return false
+ end
+ dir[filename] = text
+ else
+ local suc, err = m.saveFile(path, text)
+ if suc then
+ return true
+ end
+ optional.err[#optional.err+1] = err
+ return false
+ end
+end
+
+local function fsLoad(path, optional)
+ if path.type == 'dummy' then
+ local text = path:_open()
+ if type(text) == 'string' then
+ return text
+ else
+ optional.err[#optional.err+1] = '无法打开:' .. path:string()
+ return nil
+ end
+ else
+ local text, err = m.loadFile(path)
+ if text then
+ return text
+ else
+ optional.err[#optional.err+1] = err
+ return nil
+ end
+ end
+end
+
+local function fsCopy(source, target, optional)
+ if source.type == 'dummy' then
+ local sourceText = source:_open()
+ if not sourceText then
+ optional.err[#optional.err+1] = '无法打开:' .. source:string()
+ return false
+ end
+ return fsSave(target, sourceText, optional)
+ else
+ if target.type == 'dummy' then
+ local sourceText, err = m.loadFile(source)
+ if not sourceText then
+ optional.err[#optional.err+1] = err
+ return false
+ end
+ return fsSave(target, sourceText, optional)
+ else
+ local suc, res = pcall(fs.copy_file, source, target, true)
+ if not suc then
+ optional.err[#optional.err+1] = res
+ return false
+ end
+ end
+ end
+ return true
+end
+
+local function fsCreateDirectories(path, optional)
+ if path.type == 'dummy' then
+ return path:createDirectories()
+ end
+ local suc, res = pcall(fs.create_directories, path)
+ if not suc then
+ optional.err[#optional.err+1] = res
+ return false
+ end
+ return true
+end
+
+local function fileRemove(path, optional)
+ if optional.onRemove and optional.onRemove(path) == false then
+ return
+ end
+ if fsIsDirectory(path, optional) then
+ for child in path:list_directory() do
+ fileRemove(child, optional)
+ end
+ end
+ if fsRemove(path, optional) then
+ optional.del[#optional.del+1] = path:string()
+ end
+end
+
+local function fileCopy(source, target, optional)
+ if optional.onCopy and optional.onCopy(source, target) == false then
+ return
+ end
+ local isDir1 = fsIsDirectory(source, optional)
+ local isDir2 = fsIsDirectory(target, optional)
+ local isExists = fsExists(target, optional)
+ if isDir1 then
+ if isDir2 or fsCreateDirectories(target, optional) then
+ for filePath in source:list_directory() do
+ local name = filePath:filename():string()
+ fileCopy(filePath, target / name, optional)
+ end
+ end
+ else
+ if isExists and not isDir2 then
+ local buf1 = fsLoad(source, optional)
+ local buf2 = fsLoad(target, optional)
+ if buf1 and buf2 then
+ if buf1 ~= buf2 then
+ if fsCopy(source, target, optional) then
+ optional.mod[#optional.mod+1] = target:string()
+ end
+ end
+ end
+ else
+ if fsCopy(source, target, optional) then
+ optional.add[#optional.add+1] = target:string()
+ end
+ end
+ end
+end
+
+local function fileSync(source, target, optional)
+ if optional.onSync and optional.onSync(source, target) == false then
+ return
+ end
+ local isDir1 = fsIsDirectory(source, optional)
+ local isDir2 = fsIsDirectory(target, optional)
+ local isExists = fsExists(target, optional)
+ if isDir1 then
+ if isDir2 then
+ local fileList = m.fileList()
+ for filePath in target:list_directory() do
+ fileList[filePath] = true
+ end
+ for filePath in source:list_directory() do
+ local name = filePath:filename():string()
+ local targetPath = target / name
+ fileSync(filePath, targetPath, optional)
+ fileList[targetPath] = nil
+ end
+ for path in pairs(fileList) do
+ fileRemove(path, optional)
+ end
+ else
+ if isExists then
+ fileRemove(target, optional)
+ end
+ if fsCreateDirectories(target) then
+ for filePath in source:list_directory() do
+ local name = filePath:filename():string()
+ fileCopy(filePath, target / name, optional)
+ end
+ end
+ end
+ else
+ if isDir2 then
+ fileRemove(target, optional)
+ end
+ if isExists then
+ local buf1 = fsLoad(source, optional)
+ local buf2 = fsLoad(target, optional)
+ if buf1 and buf2 then
+ if buf1 ~= buf2 then
+ if fsCopy(source, target, optional) then
+ optional.mod[#optional.mod+1] = target:string()
+ end
+ end
+ end
+ else
+ if fsCopy(source, target, optional) then
+ optional.add[#optional.add+1] = target:string()
+ end
+ end
+ end
+end
+
+--- 文件列表
+function m.fileList(optional)
+ optional = optional or buildOptional(optional)
+ local os = platform.OS
+ local keyMap = {}
+ local fileList = {}
+ local function computeKey(path)
+ path = fsAbsolute(path, optional)
+ if not path then
+ return nil
+ end
+ local key
+ if os == 'Windows' then
+ key = path:string():lower()
+ else
+ key = path:string()
+ end
+ return key
+ end
+ return setmetatable({}, {
+ __index = function (_, path)
+ local key = computeKey(path)
+ return fileList[key]
+ end,
+ __newindex = function (_, path, value)
+ local key = computeKey(path)
+ if not key then
+ return
+ end
+ if value == nil then
+ keyMap[key] = nil
+ else
+ keyMap[key] = path
+ fileList[key] = value
+ end
+ end,
+ __pairs = function ()
+ local key, path
+ return function ()
+ key, path = next(keyMap, key)
+ return path, fileList[key]
+ end
+ end,
+ })
+end
+
+--- 删除文件(夹)
+function m.fileRemove(path, optional)
+ optional = buildOptional(optional)
+ path = fsAbsolute(path, optional)
+
+ fileRemove(path, optional)
+
+ return optional
+end
+
+--- 复制文件(夹)
+---@param source string
+---@param target string
+---@return table
+function m.fileCopy(source, target, optional)
+ optional = buildOptional(optional)
+ source = fsAbsolute(source, optional)
+ target = fsAbsolute(target, optional)
+
+ fileCopy(source, target, optional)
+
+ return optional
+end
+
+--- 同步文件(夹)
+---@param source string
+---@param target string
+---@return table
+function m.fileSync(source, target, optional)
+ optional = buildOptional(optional)
+ source = fsAbsolute(source, optional)
+ target = fsAbsolute(target, optional)
+
+ fileSync(source, target, optional)
+
+ return optional
+end
+
+return m
diff --git a/script/glob/gitignore.lua b/script/glob/gitignore.lua
new file mode 100644
index 00000000..f98a2f31
--- /dev/null
+++ b/script/glob/gitignore.lua
@@ -0,0 +1,221 @@
+local m = require 'lpeglabel'
+local matcher = require 'glob.matcher'
+
+local function prop(name, pat)
+ return m.Cg(m.Cc(true), name) * pat
+end
+
+local function object(type, pat)
+ return m.Ct(
+ m.Cg(m.Cc(type), 'type') *
+ m.Cg(pat, 'value')
+ )
+end
+
+local function expect(p, err)
+ return p + m.T(err)
+end
+
+local parser = m.P {
+ 'Main',
+ ['Sp'] = m.S(' \t')^0,
+ ['Slash'] = m.S('/\\')^1,
+ ['Main'] = m.Ct(m.V'Sp' * m.P'{' * m.V'Pattern' * (',' * expect(m.V'Pattern', 'Miss exp after ","'))^0 * m.P'}')
+ + m.Ct(m.V'Pattern')
+ + m.T'Main Failed'
+ ,
+ ['Pattern'] = m.Ct(m.V'Sp' * prop('neg', m.P'!') * expect(m.V'Unit', 'Miss exp after "!"'))
+ + m.Ct(m.V'Unit')
+ ,
+ ['NeedRoot'] = prop('root', (m.P'.' * m.V'Slash' + m.V'Slash')),
+ ['Unit'] = m.V'Sp' * m.V'NeedRoot'^-1 * expect(m.V'Exp', 'Miss exp') * m.V'Sp',
+ ['Exp'] = m.V'Sp' * (m.V'FSymbol' + object('/', m.V'Slash') + m.V'Word')^0 * m.V'Sp',
+ ['Word'] = object('word', m.Ct((m.V'CSymbol' + m.V'Char' - m.V'FSymbol')^1)),
+ ['CSymbol'] = object('*', m.P'*')
+ + object('?', m.P'?')
+ + object('[]', m.V'Range')
+ ,
+ ['Char'] = object('char', (1 - m.S',{}[]*?/\\')^1),
+ ['FSymbol'] = object('**', m.P'**'),
+ ['Range'] = m.P'[' * m.Ct(m.V'RangeUnit'^0) * m.P']'^-1,
+ ['RangeUnit'] = m.Ct(- m.P']' * m.C(m.P(1)) * (m.P'-' * - m.P']' * m.C(m.P(1)))^-1),
+}
+
+local mt = {}
+mt.__index = mt
+mt.__name = 'gitignore'
+
+function mt:addPattern(pat)
+ if type(pat) ~= 'string' then
+ return
+ end
+ self.pattern[#self.pattern+1] = pat
+ if self.options.ignoreCase then
+ pat = pat:lower()
+ end
+ local states, err = parser:match(pat)
+ if not states then
+ self.errors[#self.errors+1] = {
+ pattern = pat,
+ message = err
+ }
+ return
+ end
+ for _, state in ipairs(states) do
+ self.matcher[#self.matcher+1] = matcher(state)
+ end
+end
+
+function mt:setOption(op, val)
+ if val == nil then
+ val = true
+ end
+ self.options[op] = val
+end
+
+---@param key string | "'type'" | "'list'"
+---@param func function | "function (path) end"
+function mt:setInterface(key, func)
+ if type(func) ~= 'function' then
+ return
+ end
+ self.interface[key] = func
+end
+
+function mt:callInterface(name, ...)
+ local func = self.interface[name]
+ return func(...)
+end
+
+function mt:hasInterface(name)
+ return self.interface[name] ~= nil
+end
+
+function mt:checkDirectory(catch, path, matcher)
+ if not self:hasInterface 'type' then
+ return true
+ end
+ if not matcher:isNeedDirectory() then
+ return true
+ end
+ if #catch < #path then
+ -- if path is 'a/b/c' and catch is 'a/b'
+ -- then the catch must be a directory
+ return true
+ else
+ return self:callInterface('type', path) == 'directory'
+ end
+end
+
+function mt:simpleMatch(path)
+ for i = #self.matcher, 1, -1 do
+ local matcher = self.matcher[i]
+ local catch = matcher(path)
+ if catch and self:checkDirectory(catch, path, matcher) then
+ if matcher:isNegative() then
+ return false
+ else
+ return true
+ end
+ end
+ end
+ return nil
+end
+
+function mt:finishMatch(path)
+ local paths = {}
+ for filename in path:gmatch '[^/\\]+' do
+ paths[#paths+1] = filename
+ end
+ for i = 1, #paths do
+ local newPath = table.concat(paths, '/', 1, i)
+ local passed = self:simpleMatch(newPath)
+ if passed == true then
+ return true
+ elseif passed == false then
+ return false
+ end
+ end
+ return false
+end
+
+function mt:scan(callback)
+ local files = {}
+ if type(callback) ~= 'function' then
+ callback = nil
+ end
+ local list = {}
+ local result = self:callInterface('list', '')
+ if type(result) ~= 'table' then
+ return files
+ end
+ for _, path in ipairs(result) do
+ list[#list+1] = path:match '([^/\\]+)[/\\]*$'
+ end
+ while #list > 0 do
+ local current = list[#list]
+ if not current then
+ break
+ end
+ list[#list] = nil
+ if not self:simpleMatch(current) then
+ local fileType = self:callInterface('type', current)
+ if fileType == 'file' then
+ if callback then
+ callback(current)
+ end
+ files[#files+1] = current
+ elseif fileType == 'directory' then
+ local result = self:callInterface('list', current)
+ if type(result) == 'table' then
+ for _, path in ipairs(result) do
+ local filename = path:match '([^/\\]+)[/\\]*$'
+ if filename then
+ list[#list+1] = current .. '/' .. filename
+ end
+ end
+ end
+ end
+ end
+ end
+ return files
+end
+
+function mt:__call(path)
+ if self.options.ignoreCase then
+ path = path:lower()
+ end
+ return self:finishMatch(path)
+end
+
+return function (pattern, options, interface)
+ local self = setmetatable({
+ pattern = {},
+ options = {},
+ matcher = {},
+ errors = {},
+ interface = {},
+ }, mt)
+
+ if type(pattern) == 'table' then
+ for _, pat in ipairs(pattern) do
+ self:addPattern(pat)
+ end
+ else
+ self:addPattern(pattern)
+ end
+
+ if type(options) == 'table' then
+ for op, val in pairs(options) do
+ self:setOption(op, val)
+ end
+ end
+
+ if type(interface) == 'table' then
+ for key, func in pairs(interface) do
+ self:setInterface(key, func)
+ end
+ end
+
+ return self
+end
diff --git a/script/glob/glob.lua b/script/glob/glob.lua
new file mode 100644
index 00000000..aa8923f3
--- /dev/null
+++ b/script/glob/glob.lua
@@ -0,0 +1,122 @@
+local m = require 'lpeglabel'
+local matcher = require 'glob.matcher'
+
+local function prop(name, pat)
+ return m.Cg(m.Cc(true), name) * pat
+end
+
+local function object(type, pat)
+ return m.Ct(
+ m.Cg(m.Cc(type), 'type') *
+ m.Cg(pat, 'value')
+ )
+end
+
+local function expect(p, err)
+ return p + m.T(err)
+end
+
+local parser = m.P {
+ 'Main',
+ ['Sp'] = m.S(' \t')^0,
+ ['Slash'] = m.S('/\\')^1,
+ ['Main'] = m.Ct(m.V'Sp' * m.P'{' * m.V'Pattern' * (',' * expect(m.V'Pattern', 'Miss exp after ","'))^0 * m.P'}')
+ + m.Ct(m.V'Pattern')
+ + m.T'Main Failed'
+ ,
+ ['Pattern'] = m.Ct(m.V'Sp' * prop('neg', m.P'!') * expect(m.V'Unit', 'Miss exp after "!"'))
+ + m.Ct(m.V'Unit')
+ ,
+ ['NeedRoot'] = prop('root', (m.P'.' * m.V'Slash' + m.V'Slash')),
+ ['Unit'] = m.V'Sp' * m.V'NeedRoot'^-1 * expect(m.V'Exp', 'Miss exp') * m.V'Sp',
+ ['Exp'] = m.V'Sp' * (m.V'FSymbol' + object('/', m.V'Slash') + m.V'Word')^0 * m.V'Sp',
+ ['Word'] = object('word', m.Ct((m.V'CSymbol' + m.V'Char' - m.V'FSymbol')^1)),
+ ['CSymbol'] = object('*', m.P'*')
+ + object('?', m.P'?')
+ + object('[]', m.V'Range')
+ ,
+ ['Char'] = object('char', (1 - m.S',{}[]*?/\\')^1),
+ ['FSymbol'] = object('**', m.P'**'),
+ ['RangeWord'] = 1 - m.P']',
+ ['Range'] = m.P'[' * m.Ct(m.V'RangeUnit'^0) * m.P']'^-1,
+ ['RangeUnit'] = m.Ct(m.C(m.V'RangeWord') * m.P'-' * m.C(m.V'RangeWord'))
+ + m.V'RangeWord',
+}
+
+local mt = {}
+mt.__index = mt
+mt.__name = 'glob'
+
+function mt:addPattern(pat)
+ if type(pat) ~= 'string' then
+ return
+ end
+ self.pattern[#self.pattern+1] = pat
+ if self.options.ignoreCase then
+ pat = pat:lower()
+ end
+ local states, err = parser:match(pat)
+ if not states then
+ self.errors[#self.errors+1] = {
+ pattern = pat,
+ message = err
+ }
+ return
+ end
+ for _, state in ipairs(states) do
+ if state.neg then
+ self.refused[#self.refused+1] = matcher(state)
+ else
+ self.passed[#self.passed+1] = matcher(state)
+ end
+ end
+end
+
+function mt:setOption(op, val)
+ if val == nil then
+ val = true
+ end
+ self.options[op] = val
+end
+
+function mt:__call(path)
+ if self.options.ignoreCase then
+ path = path:lower()
+ end
+ for _, refused in ipairs(self.refused) do
+ if refused(path) then
+ return false
+ end
+ end
+ for _, passed in ipairs(self.passed) do
+ if passed(path) then
+ return true
+ end
+ end
+ return false
+end
+
+return function (pattern, options)
+ local self = setmetatable({
+ pattern = {},
+ options = {},
+ passed = {},
+ refused = {},
+ errors = {},
+ }, mt)
+
+ if type(pattern) == 'table' then
+ for _, pat in ipairs(pattern) do
+ self:addPattern(pat)
+ end
+ else
+ self:addPattern(pattern)
+ end
+
+ if type(options) == 'table' then
+ for op, val in pairs(options) do
+ self:setOption(op, val)
+ end
+ end
+ return self
+end
diff --git a/script/glob/init.lua b/script/glob/init.lua
new file mode 100644
index 00000000..6578a0d4
--- /dev/null
+++ b/script/glob/init.lua
@@ -0,0 +1,4 @@
+return {
+ glob = require 'glob.glob',
+ gitignore = require 'glob.gitignore',
+}
diff --git a/script/glob/matcher.lua b/script/glob/matcher.lua
new file mode 100644
index 00000000..f4c2b12c
--- /dev/null
+++ b/script/glob/matcher.lua
@@ -0,0 +1,151 @@
+local m = require 'lpeglabel'
+
+local Slash = m.S('/\\')^1
+local Symbol = m.S',{}[]*?/\\'
+local Char = 1 - Symbol
+local Path = Char^1 * Slash
+local NoWord = #(m.P(-1) + Symbol)
+local function whatHappened()
+ return m.Cmt(m.P(1)^1, function (...)
+ print(...)
+ end)
+end
+
+local mt = {}
+mt.__index = mt
+mt.__name = 'matcher'
+
+function mt:exp(state, index)
+ local exp = state[index]
+ if not exp then
+ return
+ end
+ if exp.type == 'word' then
+ return self:word(exp, state, index + 1)
+ elseif exp.type == 'char' then
+ return self:char(exp, state, index + 1)
+ elseif exp.type == '**' then
+ return self:anyPath(exp, state, index + 1)
+ elseif exp.type == '*' then
+ return self:anyChar(exp, state, index + 1)
+ elseif exp.type == '?' then
+ return self:oneChar(exp, state, index + 1)
+ elseif exp.type == '[]' then
+ return self:range(exp, state, index + 1)
+ elseif exp.type == '/' then
+ return self:slash(exp, state, index + 1)
+ end
+end
+
+function mt:word(exp, state, index)
+ local current = self:exp(exp.value, 1)
+ local after = self:exp(state, index)
+ if after then
+ return current * Slash * after
+ else
+ return current
+ end
+end
+
+function mt:char(exp, state, index)
+ local current = m.P(exp.value)
+ local after = self:exp(state, index)
+ if after then
+ return current * after * NoWord
+ else
+ return current * NoWord
+ end
+end
+
+function mt:anyPath(_, state, index)
+ local after = self:exp(state, index)
+ if after then
+ return m.P {
+ 'Main',
+ Main = after
+ + Path * m.V'Main'
+ }
+ else
+ return Path^0
+ end
+end
+
+function mt:anyChar(_, state, index)
+ local after = self:exp(state, index)
+ if after then
+ return m.P {
+ 'Main',
+ Main = after
+ + Char * m.V'Main'
+ }
+ else
+ return Char^0
+ end
+end
+
+function mt:oneChar(_, state, index)
+ local after = self:exp(state, index)
+ if after then
+ return Char * after
+ else
+ return Char
+ end
+end
+
+function mt:range(exp, state, index)
+ local after = self:exp(state, index)
+ local ranges = {}
+ local selects = {}
+ for _, range in ipairs(exp.value) do
+ if #range == 1 then
+ selects[#selects+1] = range[1]
+ elseif #range == 2 then
+ ranges[#ranges+1] = range[1] .. range[2]
+ end
+ end
+ local current = m.S(table.concat(selects)) + m.R(table.unpack(ranges))
+ if after then
+ return current * after
+ else
+ return current
+ end
+end
+
+function mt:slash(_, state, index)
+ local after = self:exp(state, index)
+ if after then
+ return after
+ else
+ self.needDirectory = true
+ return nil
+ end
+end
+
+function mt:pattern(state)
+ if state.root then
+ return m.C(self:exp(state, 1))
+ else
+ return m.C(self:anyPath(nil, state, 1))
+ end
+end
+
+function mt:isNeedDirectory()
+ return self.needDirectory == true
+end
+
+function mt:isNegative()
+ return self.state.neg == true
+end
+
+function mt:__call(path)
+ return self.matcher:match(path)
+end
+
+return function (state, options)
+ local self = setmetatable({
+ options = options,
+ state = state,
+ }, mt)
+ self.matcher = self:pattern(state)
+ return self
+end
diff --git a/script/json-beautify.lua b/script/json-beautify.lua
new file mode 100644
index 00000000..1d2a6cc0
--- /dev/null
+++ b/script/json-beautify.lua
@@ -0,0 +1,120 @@
+local json = require "json"
+local type = type
+local next = next
+local error = error
+local table_concat = table.concat
+local table_sort = table.sort
+local string_rep = string.rep
+local math_type = math.type
+local setmetatable = setmetatable
+local getmetatable = getmetatable
+
+local statusMark
+local statusQue
+local statusDep
+local statusOpt
+
+local defaultOpt = {
+ newline = "\n",
+ indent = " ",
+}
+defaultOpt.__index = defaultOpt
+
+local function encode_newline()
+ statusQue[#statusQue+1] = statusOpt.newline..string_rep(statusOpt.indent, statusDep)
+end
+
+local encode_map = {}
+for k ,v in next, json.encode_map do
+ encode_map[k] = v
+end
+
+local encode_string = json.encode_map.string
+
+local function encode(v)
+ local res = encode_map[type(v)](v)
+ statusQue[#statusQue+1] = res
+end
+
+function encode_map.table(t)
+ local first_val = next(t)
+ if first_val == nil then
+ if getmetatable(t) == json.object then
+ return "{}"
+ else
+ return "[]"
+ end
+ end
+ if statusMark[t] then
+ error("circular reference")
+ end
+ statusMark[t] = true
+ if type(first_val) == 'string' then
+ local key = {}
+ for k in next, t do
+ if type(k) ~= "string" then
+ error("invalid table: mixed or invalid key types")
+ end
+ key[#key+1] = k
+ end
+ table_sort(key)
+ statusQue[#statusQue+1] = "{"
+ statusDep = statusDep + 1
+ encode_newline()
+ local k = key[1]
+ statusQue[#statusQue+1] = encode_string(k)
+ statusQue[#statusQue+1] = ": "
+ encode(t[k])
+ for i = 2, #key do
+ local k = key[i]
+ statusQue[#statusQue+1] = ","
+ encode_newline()
+ statusQue[#statusQue+1] = encode_string(k)
+ statusQue[#statusQue+1] = ": "
+ encode(t[k])
+ end
+ statusDep = statusDep - 1
+ encode_newline()
+ statusMark[t] = nil
+ return "}"
+ else
+ local max = 0
+ for k in next, t do
+ if math_type(k) ~= "integer" or k <= 0 then
+ error("invalid table: mixed or invalid key types")
+ end
+ if max < k then
+ max = k
+ end
+ end
+ statusQue[#statusQue+1] = "["
+ statusDep = statusDep + 1
+ encode_newline()
+ encode(t[1])
+ for i = 2, max do
+ statusQue[#statusQue+1] = ","
+ encode_newline()
+ encode(t[i])
+ end
+ statusDep = statusDep - 1
+ encode_newline()
+ statusMark[t] = nil
+ return "]"
+ end
+end
+
+local function beautify(v, option)
+ if type(v) == "string" then
+ v = json.decode(v)
+ end
+ statusMark = {}
+ statusQue = {}
+ statusDep = 0
+ statusOpt = option and setmetatable(option, defaultOpt) or defaultOpt
+ encode(v)
+ return table_concat(statusQue)
+end
+
+json.beautify = beautify
+
+return json
diff --git a/script/json.lua b/script/json.lua
new file mode 100644
index 00000000..46261d7d
--- /dev/null
+++ b/script/json.lua
@@ -0,0 +1,450 @@
+local type = type
+local next = next
+local error = error
+local tonumber = tonumber
+local tostring = tostring
+local utf8_char = utf8.char
+local table_concat = table.concat
+local table_sort = table.sort
+local string_char = string.char
+local string_byte = string.byte
+local string_find = string.find
+local string_match = string.match
+local string_gsub = string.gsub
+local string_sub = string.sub
+local string_format = string.format
+local math_type = math.type
+local setmetatable = setmetatable
+local getmetatable = getmetatable
+local Inf = math.huge
+
+local json = {}
+json.object = {}
+
+-- json.encode --
+local statusMark
+local statusQue
+
+local encode_map = {}
+
+local encode_escape_map = {
+ [ "\"" ] = "\\\"",
+ [ "\\" ] = "\\\\",
+ [ "/" ] = "\\/",
+ [ "\b" ] = "\\b",
+ [ "\f" ] = "\\f",
+ [ "\n" ] = "\\n",
+ [ "\r" ] = "\\r",
+ [ "\t" ] = "\\t",
+}
+
+local decode_escape_set = {}
+local decode_escape_map = {}
+for k, v in next, encode_escape_map do
+ decode_escape_map[v] = k
+ decode_escape_set[string_byte(v, 2)] = true
+end
+
+for i = 0, 31 do
+ local c = string_char(i)
+ if not encode_escape_map[c] then
+ encode_escape_map[c] = string_format("\\u%04x", i)
+ end
+end
+
+local function encode(v)
+ local res = encode_map[type(v)](v)
+ statusQue[#statusQue+1] = res
+end
+
+encode_map["nil"] = function ()
+ return "null"
+end
+
+function encode_map.string(v)
+ return '"' .. string_gsub(v, '[\0-\31\\"]', encode_escape_map) .. '"'
+end
+local encode_string = encode_map.string
+
+local function convertreal(v)
+ local g = string_format('%.16g', v)
+ if tonumber(g) == v then
+ return g
+ end
+ return string_format('%.17g', v)
+end
+
+function encode_map.number(v)
+ if v ~= v or v <= -Inf or v >= Inf then
+ error("unexpected number value '" .. tostring(v) .. "'")
+ end
+ return string_gsub(convertreal(v), ',', '.')
+end
+
+function encode_map.boolean(v)
+ if v then
+ return "true"
+ else
+ return "false"
+ end
+end
+
+function encode_map.table(t)
+ local first_val = next(t)
+ if first_val == nil then
+ if getmetatable(t) == json.object then
+ return "{}"
+ else
+ return "[]"
+ end
+ end
+ if statusMark[t] then
+ error("circular reference")
+ end
+ statusMark[t] = true
+ if type(first_val) == 'string' then
+ local key = {}
+ for k in next, t do
+ if type(k) ~= "string" then
+ error("invalid table: mixed or invalid key types")
+ end
+ key[#key+1] = k
+ end
+ table_sort(key)
+ statusQue[#statusQue+1] = "{"
+ local k = key[1]
+ statusQue[#statusQue+1] = encode_string(k)
+ statusQue[#statusQue+1] = ":"
+ encode(t[k])
+ for i = 2, #key do
+ local k = key[i]
+ statusQue[#statusQue+1] = ","
+ statusQue[#statusQue+1] = encode_string(k)
+ statusQue[#statusQue+1] = ":"
+ encode(t[k])
+ end
+ statusMark[t] = nil
+ return "}"
+ else
+ local max = 0
+ for k in next, t do
+ if math_type(k) ~= "integer" or k <= 0 then
+ error("invalid table: mixed or invalid key types")
+ end
+ if max < k then
+ max = k
+ end
+ end
+ statusQue[#statusQue+1] = "["
+ encode(t[1])
+ for i = 2, max do
+ statusQue[#statusQue+1] = ","
+ encode(t[i])
+ end
+ statusMark[t] = nil
+ return "]"
+ end
+end
+
+local function encode_unexpected(v)
+ if v == json.null then
+ return "null"
+ else
+ error("unexpected type '"..type(v).."'")
+ end
+end
+encode_map[ "function" ] = encode_unexpected
+encode_map[ "userdata" ] = encode_unexpected
+encode_map[ "thread" ] = encode_unexpected
+
+function json.encode(v)
+ statusMark = {}
+ statusQue = {}
+ encode(v)
+ return table_concat(statusQue)
+end
+
+json.encode_map = encode_map
+
+-- json.decode --
+
+local statusBuf
+local statusPos
+local statusTop
+local statusAry = {}
+local statusRef = {}
+
+local function find_line()
+ local line = 1
+ local pos = 1
+ while true do
+ local f, _, nl1, nl2 = string_find(statusBuf, '([\n\r])([\n\r]?)', pos)
+ if not f then
+ return line, statusPos - pos + 1
+ end
+ local newpos = f + ((nl1 == nl2 or nl2 == '') and 1 or 2)
+ if newpos > statusPos then
+ return line, statusPos - pos + 1
+ end
+ pos = newpos
+ line = line + 1
+ end
+end
+
+local function decode_error(msg)
+ error(string_format("ERROR: %s at line %d col %d", msg, find_line()))
+end
+
+local function get_word()
+ return string_match(statusBuf, "^[^ \t\r\n%]},]*", statusPos)
+end
+
+local function next_byte()
+ statusPos = string_find(statusBuf, "[^ \t\r\n]", statusPos)
+ if statusPos then
+ return string_byte(statusBuf, statusPos)
+ end
+ statusPos = #statusBuf + 1
+ decode_error("unexpected character '<eol>'")
+end
+
+local function expect_byte(c)
+ local _, pos = string_find(statusBuf, c, statusPos)
+ if not pos then
+ decode_error(string_format("expected '%s'", string_sub(c, #c)))
+ end
+ statusPos = pos
+end
+
+local function decode_unicode_surrogate(s1, s2)
+ return utf8_char(0x10000 + (tonumber(s1, 16) - 0xd800) * 0x400 + (tonumber(s2, 16) - 0xdc00))
+end
+
+local function decode_unicode_escape(s)
+ return utf8_char(tonumber(s, 16))
+end
+
+local function decode_string()
+ local has_unicode_escape = false
+ local has_escape = false
+ local i = statusPos + 1
+ while true do
+ i = string_find(statusBuf, '["\\\0-\31]', i)
+ if not i then
+ decode_error "expected closing quote for string"
+ end
+ local x = string_byte(statusBuf, i)
+ if x < 32 then
+ statusPos = i
+ decode_error "control character in string"
+ end
+ if x == 34 --[[ '"' ]] then
+ local s = string_sub(statusBuf, statusPos + 1, i - 1)
+ if has_unicode_escape then
+ s = string_gsub(string_gsub(s
+ , "\\u([dD][89aAbB]%x%x)\\u([dD][c-fC-F]%x%x)", decode_unicode_surrogate)
+ , "\\u(%x%x%x%x)", decode_unicode_escape)
+ end
+ if has_escape then
+ s = string_gsub(s, "\\.", decode_escape_map)
+ end
+ statusPos = i + 1
+ return s
+ end
+ --assert(x == 92 --[[ "\\" ]])
+ local nx = string_byte(statusBuf, i+1)
+ if nx == 117 --[[ "u" ]] then
+ if not string_match(statusBuf, "^%x%x%x%x", i+2) then
+ statusPos = i
+ decode_error "invalid unicode escape in string"
+ end
+ has_unicode_escape = true
+ i = i + 6
+ else
+ if not decode_escape_set[nx] then
+ statusPos = i
+ decode_error("invalid escape char '" .. (nx and string_char(nx) or "<eol>") .. "' in string")
+ end
+ has_escape = true
+ i = i + 2
+ end
+ end
+end
+
+local function decode_number()
+ local word = get_word()
+ if not (
+ string_match(word, '^.[0-9]*$')
+ or string_match(word, '^.[0-9]*%.[0-9]+$')
+ or string_match(word, '^.[0-9]*[Ee][+-]?[0-9]+$')
+ or string_match(word, '^.[0-9]*%.[0-9]+[Ee][+-]?[0-9]+$')
+ ) then
+ decode_error("invalid number '" .. word .. "'")
+ end
+ statusPos = statusPos + #word
+ return tonumber(word)
+end
+
+local function decode_number_negative()
+ local word = get_word()
+ if not (
+ string_match(word, '^.[1-9][0-9]*$')
+ or string_match(word, '^.[1-9][0-9]*%.[0-9]+$')
+ or string_match(word, '^.[1-9][0-9]*[Ee][+-]?[0-9]+$')
+ or string_match(word, '^.[1-9][0-9]*%.[0-9]+[Ee][+-]?[0-9]+$')
+ or word == "-0"
+ or string_match(word, '^.0%.[0-9]+$')
+ or string_match(word, '^.0[Ee][+-]?[0-9]+$')
+ or string_match(word, '^.0%.[0-9]+[Ee][+-]?[0-9]+$')
+ ) then
+ decode_error("invalid number '" .. word .. "'")
+ end
+ statusPos = statusPos + #word
+ return tonumber(word)
+end
+
+local function decode_number_zero()
+ local word = get_word()
+ if not (
+ #word == 1
+ or string_match(word, '^.%.[0-9]+$')
+ or string_match(word, '^.[Ee][+-]?[0-9]+$')
+ or string_match(word, '^.%.[0-9]+[Ee][+-]?[0-9]+$')
+ ) then
+ decode_error("invalid number '" .. word .. "'")
+ end
+ statusPos = statusPos + #word
+ return tonumber(word)
+end
+
+local function decode_true()
+ if string_sub(statusBuf, statusPos, statusPos+3) ~= "true" then
+ decode_error("invalid literal '" .. get_word() .. "'")
+ end
+ statusPos = statusPos + 4
+ return true
+end
+
+local function decode_false()
+ if string_sub(statusBuf, statusPos, statusPos+4) ~= "false" then
+ decode_error("invalid literal '" .. get_word() .. "'")
+ end
+ statusPos = statusPos + 5
+ return false
+end
+
+local function decode_null()
+ if string_sub(statusBuf, statusPos, statusPos+3) ~= "null" then
+ decode_error("invalid literal '" .. get_word() .. "'")
+ end
+ statusPos = statusPos + 4
+ return json.null
+end
+
+local function decode_array()
+ statusPos = statusPos + 1
+ local res = {}
+ if next_byte() == 93 --[[ "]" ]] then
+ statusPos = statusPos + 1
+ return res
+ end
+ statusTop = statusTop + 1
+ statusAry[statusTop] = true
+ statusRef[statusTop] = res
+ return res
+end
+
+local function decode_object()
+ statusPos = statusPos + 1
+ local res = {}
+ if next_byte() == 125 --[[ "}" ]] then
+ statusPos = statusPos + 1
+ return setmetatable(res, json.object)
+ end
+ statusTop = statusTop + 1
+ statusAry[statusTop] = false
+ statusRef[statusTop] = res
+ return res
+end
+
+local decode_uncompleted_map = {
+ [ string_byte '"' ] = decode_string,
+ [ string_byte "0" ] = decode_number_zero,
+ [ string_byte "1" ] = decode_number,
+ [ string_byte "2" ] = decode_number,
+ [ string_byte "3" ] = decode_number,
+ [ string_byte "4" ] = decode_number,
+ [ string_byte "5" ] = decode_number,
+ [ string_byte "6" ] = decode_number,
+ [ string_byte "7" ] = decode_number,
+ [ string_byte "8" ] = decode_number,
+ [ string_byte "9" ] = decode_number,
+ [ string_byte "-" ] = decode_number_negative,
+ [ string_byte "t" ] = decode_true,
+ [ string_byte "f" ] = decode_false,
+ [ string_byte "n" ] = decode_null,
+ [ string_byte "[" ] = decode_array,
+ [ string_byte "{" ] = decode_object,
+}
+local function unexpected_character()
+ decode_error("unexpected character '" .. string_sub(statusBuf, statusPos, statusPos) .. "'")
+end
+
+local decode_map = {}
+for i = 0, 255 do
+ decode_map[i] = decode_uncompleted_map[i] or unexpected_character
+end
+
+local function decode()
+ return decode_map[next_byte()]()
+end
+
+local function decode_item()
+ local top = statusTop
+ local ref = statusRef[top]
+ if statusAry[top] then
+ ref[#ref+1] = decode()
+ else
+ expect_byte '^[ \t\r\n]*"'
+ local key = decode_string()
+ expect_byte '^[ \t\r\n]*:'
+ statusPos = statusPos + 1
+ ref[key] = decode()
+ end
+ if top == statusTop then
+ repeat
+ local chr = next_byte(); statusPos = statusPos + 1
+ if chr == 44 --[[ "," ]] then
+ return
+ end
+ if statusAry[statusTop] then
+ if chr ~= 93 --[[ "]" ]] then decode_error "expected ']' or ','" end
+ else
+ if chr ~= 125 --[[ "}" ]] then decode_error "expected '}' or ','" end
+ end
+ statusTop = statusTop - 1
+ until statusTop == 0
+ end
+end
+
+function json.decode(str)
+ if type(str) ~= "string" then
+ error("expected argument of type string, got " .. type(str))
+ end
+ statusBuf = str
+ statusPos = 1
+ statusTop = 0
+ local res = decode()
+ while statusTop > 0 do
+ decode_item()
+ end
+ if string_find(statusBuf, "[^ \t\r\n]", statusPos) then
+ decode_error "trailing garbage"
+ end
+ return res
+end
+
+-- Generate a lightuserdata
+json.null = debug.upvalueid(decode, 1)
+
+return json
diff --git a/script/jsonrpc.lua b/script/jsonrpc.lua
new file mode 100644
index 00000000..17e6e73d
--- /dev/null
+++ b/script/jsonrpc.lua
@@ -0,0 +1,64 @@
+local json = require 'json'
+local pcall = pcall
+local tonumber = tonumber
+local util = require 'utility'
+local log = require 'brave.log'
+
+---@class jsonrpc
+local m = {}
+m.type = 'jsonrpc'
+
+function m.encode(pack)
+ pack.jsonrpc = '2.0'
+ local content = json.encode(pack)
+ local buf = ('Content-Length: %d\r\n\r\n%s'):format(#content, content)
+ return buf
+end
+
+local function readProtoHead(reader)
+ local head = {}
+ while true do
+ local line = reader 'L'
+ if line == nil then
+ -- 说明管道已经关闭了
+ return nil, 'Disconnected!'
+ end
+ if line == '\r\n' then
+ break
+ end
+ local k, v = line:match '^([^:]+)%s*%:%s*(.+)\r\n$'
+ if not k then
+ return nil, 'Proto header error: ' .. line
+ end
+ if k == 'Content-Length' then
+ v = tonumber(v)
+ end
+ head[k] = v
+ end
+ return head
+end
+
+function m.decode(reader, errHandle)
+ local head, err = readProtoHead(reader, errHandle)
+ if not head then
+ return nil, err
+ end
+ local len = head['Content-Length']
+ if not len then
+ return nil, 'Proto header error: ' .. util.dump(head)
+ end
+ local content = reader(len)
+ if not content then
+ return nil, 'Proto read error'
+ end
+ local null = json.null
+ json.null = nil
+ local suc, res = pcall(json.decode, content)
+ json.null = null
+ if not suc then
+ return nil, 'Proto parse error: ' .. res
+ end
+ return res
+end
+
+return m
diff --git a/script/language.lua b/script/language.lua
new file mode 100644
index 00000000..6077e4c6
--- /dev/null
+++ b/script/language.lua
@@ -0,0 +1,140 @@
+local fs = require 'bee.filesystem'
+local lni = require 'lni'
+local util = require 'utility'
+
+local function supportLanguage()
+ local list = {}
+ for path in (ROOT / 'locale'):list_directory() do
+ if fs.is_directory(path) then
+ list[#list+1] = path:filename():string():lower()
+ end
+ end
+ return list
+end
+
+local function osLanguage()
+ return LANG:lower()
+end
+
+local function getLanguage(id)
+ local support = supportLanguage()
+ -- 检查是否支持语言
+ if support[id] then
+ return id
+ end
+ -- 根据语言的前2个字母来找近似语言
+ for _, lang in ipairs(support) do
+ if lang:sub(1, 2) == id:sub(1, 2) then
+ return lang
+ end
+ end
+ -- 使用英文
+ return 'enUS'
+end
+
+local function loadFileByLanguage(name, language)
+ local path = ROOT / 'locale' / language / (name .. '.lni')
+ local buf = util.loadFile(path:string())
+ if not buf then
+ return {}
+ end
+ local suc, tbl = xpcall(lni, log.error, buf, path:string())
+ if not suc then
+ return {}
+ end
+ return tbl
+end
+
+local function formatAsArray(str, ...)
+ local index = 0
+ local args = {...}
+ return str:gsub('%{(.-)%}', function (pat)
+ local id, fmt
+ local pos = pat:find(':', 1, true)
+ if pos then
+ id = pat:sub(1, pos-1)
+ fmt = pat:sub(pos+1)
+ else
+ id = pat
+ fmt = 's'
+ end
+ id = tonumber(id)
+ if not id then
+ index = index + 1
+ id = index
+ end
+ return ('%'..fmt):format(args[id])
+ end)
+end
+
+local function formatAsTable(str, ...)
+ local args = ...
+ return str:gsub('%{(.-)%}', function (pat)
+ local id, fmt
+ local pos = pat:find(':', 1, true)
+ if pos then
+ id = pat:sub(1, pos-1)
+ fmt = pat:sub(pos+1)
+ else
+ id = pat
+ fmt = 's'
+ end
+ if not id then
+ return
+ end
+ return ('%'..fmt):format(args[id])
+ end)
+end
+
+local function loadLang(name, language)
+ local tbl = loadFileByLanguage(name, 'en-US')
+ if language ~= 'en-US' then
+ local other = loadFileByLanguage(name, language)
+ for k, v in pairs(other) do
+ tbl[k] = v
+ end
+ end
+ return setmetatable(tbl, {
+ __index = function (self, key)
+ self[key] = key
+ return key
+ end,
+ __call = function (self, key, ...)
+ local str = self[key]
+ if not ... then
+ return str
+ end
+ local suc, res
+ if type(...) == 'table' then
+ suc, res = pcall(formatAsTable, str, ...)
+ else
+ suc, res = pcall(formatAsArray, str, ...)
+ end
+ if suc then
+ return res
+ else
+ -- 这里不能使用翻译,以免死循环
+ log.warn(('[%s][%s-%s] formated error: %s'):format(
+ language, name, key, str
+ ))
+ return str
+ end
+ end,
+ })
+end
+
+local function init()
+ local id = osLanguage()
+ local language = getLanguage(id)
+ log.info(('VSC language: %s'):format(id))
+ log.info(('LS language: %s'):format(language))
+ return setmetatable({ id = language }, {
+ __index = function (self, name)
+ local tbl = loadLang(name, language)
+ self[name] = tbl
+ return tbl
+ end,
+ })
+end
+
+return init()
diff --git a/script/library.lua b/script/library.lua
new file mode 100644
index 00000000..5a48499b
--- /dev/null
+++ b/script/library.lua
@@ -0,0 +1,205 @@
+local lni = require 'lni'
+local fs = require 'bee.filesystem'
+local config = require 'config'
+local util = require 'utility'
+local lang = require 'language'
+local client = require 'provider.client'
+
+local m = {}
+
+local function getDocFormater()
+ local version = config.config.runtime.version
+ if client.client() == 'vscode' then
+ if version == 'Lua 5.1' then
+ return 'HOVER_NATIVE_DOCUMENT_LUA51'
+ elseif version == 'Lua 5.2' then
+ return 'HOVER_NATIVE_DOCUMENT_LUA52'
+ elseif version == 'Lua 5.3' then
+ return 'HOVER_NATIVE_DOCUMENT_LUA53'
+ elseif version == 'Lua 5.4' then
+ return 'HOVER_NATIVE_DOCUMENT_LUA54'
+ elseif version == 'LuaJIT' then
+ return 'HOVER_NATIVE_DOCUMENT_LUAJIT'
+ end
+ else
+ if version == 'Lua 5.1' then
+ return 'HOVER_DOCUMENT_LUA51'
+ elseif version == 'Lua 5.2' then
+ return 'HOVER_DOCUMENT_LUA52'
+ elseif version == 'Lua 5.3' then
+ return 'HOVER_DOCUMENT_LUA53'
+ elseif version == 'Lua 5.4' then
+ return 'HOVER_DOCUMENT_LUA54'
+ elseif version == 'LuaJIT' then
+ return 'HOVER_DOCUMENT_LUAJIT'
+ end
+ end
+end
+
+local function convertLink(text)
+ local fmt = getDocFormater()
+ return text:gsub('%$([%.%w]+)', function (name)
+ if fmt then
+ return ('[%s](%s)'):format(name, lang.script(fmt, 'pdf-' .. name))
+ else
+ return ('`%s`'):format(name)
+ end
+ end):gsub('§([%.%w]+)', function (name)
+ if fmt then
+ return ('[§%s](%s)'):format(name, lang.script(fmt, name))
+ else
+ return ('`%s`'):format(name)
+ end
+ end)
+end
+
+local function createViewDocument(name)
+ local fmt = getDocFormater()
+ if not fmt then
+ return nil
+ end
+ return ('[%s](%s)'):format(lang.script.HOVER_VIEW_DOCUMENTS, lang.script(fmt, 'pdf-' .. name))
+end
+
+local function compileSingleMetaDoc(script, metaLang)
+ local middleBuf = {}
+ local compileBuf = {}
+
+ local last = 1
+ for start, lua, finish in script:gmatch '()%-%-%-%#([^\n\r]*)()' do
+ middleBuf[#middleBuf+1] = ('PUSH [===[%s]===]'):format(script:sub(last, start - 1))
+ middleBuf[#middleBuf+1] = lua
+ last = finish
+ end
+ middleBuf[#middleBuf+1] = ('PUSH [===[%s]===]'):format(script:sub(last))
+ local middleScript = table.concat(middleBuf, '\n')
+ local version, jit
+ if config.config.runtime.version == 'LuaJIT' then
+ version = 5.1
+ jit = true
+ else
+ version = tonumber(config.config.runtime.version:sub(-3))
+ jit = false
+ end
+
+ local env = setmetatable({
+ VERSION = version,
+ JIT = jit,
+ PUSH = function (text)
+ compileBuf[#compileBuf+1] = text
+ end,
+ DES = function (name)
+ local des = metaLang[name]
+ if not des then
+ des = ('Miss locale <%s>'):format(name)
+ end
+ compileBuf[#compileBuf+1] = '---\n'
+ for line in util.eachLine(des) do
+ compileBuf[#compileBuf+1] = '---'
+ compileBuf[#compileBuf+1] = convertLink(line)
+ compileBuf[#compileBuf+1] = '\n'
+ end
+ local viewDocument = createViewDocument(name)
+ if viewDocument then
+ compileBuf[#compileBuf+1] = '---\n---'
+ compileBuf[#compileBuf+1] = viewDocument
+ compileBuf[#compileBuf+1] = '\n'
+ end
+ compileBuf[#compileBuf+1] = '---\n'
+ end,
+ DESENUM = function (name)
+ local des = metaLang[name]
+ if not des then
+ des = ('Miss locale <%s>'):format(name)
+ end
+ compileBuf[#compileBuf+1] = convertLink(des)
+ compileBuf[#compileBuf+1] = '\n'
+ end,
+ ALIVE = function (str)
+ local isAlive
+ for piece in str:gmatch '[^%,]+' do
+ if piece:sub(1, 1) == '>' then
+ local alive = tonumber(piece:sub(2))
+ if not alive or version >= alive then
+ isAlive = true
+ break
+ end
+ elseif piece:sub(1, 1) == '<' then
+ local alive = tonumber(piece:sub(2))
+ if not alive or version <= alive then
+ isAlive = true
+ break
+ end
+ else
+ local alive = tonumber(piece)
+ if not alive or version == alive then
+ isAlive = true
+ break
+ end
+ end
+ end
+ if not isAlive then
+ compileBuf[#compileBuf+1] = '---@deprecated\n'
+ end
+ end,
+ }, { __index = _ENV })
+
+ util.saveFile((ROOT / 'log' / 'middleScript.lua'):string(), middleScript)
+
+ assert(load(middleScript, middleScript, 't', env))()
+ return table.concat(compileBuf)
+end
+
+local function loadMetaLocale(langID, result)
+ result = result or {}
+ local path = (ROOT / 'locale' / langID / 'meta.lni'):string()
+ local lniContent = util.loadFile(path)
+ if lniContent then
+ xpcall(lni, log.error, lniContent, path, {result})
+ end
+ return result
+end
+
+local function compileMetaDoc()
+ local langID = lang.id
+ local version = config.config.runtime.version
+ local metapath = ROOT / 'meta' / config.config.runtime.meta:gsub('%$%{(.-)%}', {
+ version = version,
+ language = langID,
+ })
+ if fs.exists(metapath) then
+ --return
+ end
+
+ local metaLang = loadMetaLocale('en-US')
+ if langID ~= 'en-US' then
+ loadMetaLocale(langID, metaLang)
+ end
+ --log.debug('metaLang:', util.dump(metaLang))
+
+ m.metaPath = metapath:string()
+ m.metaPaths = {}
+ fs.create_directory(metapath)
+ local templateDir = ROOT / 'meta' / 'template'
+ for fullpath in templateDir:list_directory() do
+ local filename = fullpath:filename()
+ local metaDoc = compileSingleMetaDoc(util.loadFile(fullpath:string()), metaLang)
+ local filepath = metapath / filename
+ util.saveFile(filepath:string(), metaDoc)
+ m.metaPaths[#m.metaPaths+1] = filepath:string()
+ end
+end
+
+local function initFromMetaDoc()
+ compileMetaDoc()
+end
+
+local function init()
+ initFromMetaDoc()
+end
+
+function m.init()
+ init()
+end
+
+return m
diff --git a/script/log.lua b/script/log.lua
new file mode 100644
index 00000000..169fed7f
--- /dev/null
+++ b/script/log.lua
@@ -0,0 +1,142 @@
+local fs = require 'bee.filesystem'
+
+local osTime = os.time
+local osClock = os.clock
+local osDate = os.date
+local ioOpen = io.open
+local tablePack = table.pack
+local tableConcat = table.concat
+local tostring = tostring
+local debugTraceBack = debug.traceback
+local mathModf = math.modf
+local debugGetInfo = debug.getinfo
+local ioStdErr = io.stderr
+
+_ENV = nil
+
+local m = {}
+
+m.file = nil
+m.startTime = osTime() - osClock()
+m.size = 0
+m.maxSize = 100 * 1024 * 1024
+
+local function trimSrc(src)
+ src = src:sub(m.prefixLen + 3, -5)
+ src = src:gsub('^[/\\]+', '')
+ src = src:gsub('[\\/]+', '.')
+ return src
+end
+
+local function init_log_file()
+ if not m.file then
+ m.file = ioOpen(m.path, 'w')
+ if not m.file then
+ return
+ end
+ m.file:write('')
+ m.file:close()
+ m.file = ioOpen(m.path, 'ab')
+ if not m.file then
+ return
+ end
+ m.file:setvbuf 'no'
+ end
+end
+
+local function pushLog(level, ...)
+ if not m.path then
+ return
+ end
+ local t = tablePack(...)
+ for i = 1, t.n do
+ t[i] = tostring(t[i])
+ end
+ local str = tableConcat(t, '\t', 1, t.n)
+ if level == 'error' then
+ str = str .. '\n' .. debugTraceBack(nil, 3)
+ end
+ local info = debugGetInfo(3, 'Sl')
+ return m.raw(0, level, str, info.source, info.currentline, osClock())
+end
+
+function m.info(...)
+ pushLog('info', ...)
+end
+
+function m.debug(...)
+ pushLog('debug', ...)
+end
+
+function m.trace(...)
+ pushLog('trace', ...)
+end
+
+function m.warn(...)
+ pushLog('warn', ...)
+end
+
+function m.error(...)
+ return pushLog('error', ...)
+end
+
+function m.raw(thd, level, msg, source, currentline, clock)
+ if level == 'error' then
+ ioStdErr:write(msg .. '\n')
+ end
+ if m.size > m.maxSize then
+ return
+ end
+ init_log_file()
+ if not m.file then
+ return ''
+ end
+ local sec, ms = mathModf(m.startTime + clock)
+ local timestr = osDate('%H:%M:%S', sec)
+ local agl = ''
+ if #level < 5 then
+ agl = (' '):rep(5 - #level)
+ end
+ local buf
+ if currentline == -1 then
+ buf = ('[%s.%03.f][%s]%s[#%d]: %s\n'):format(timestr, ms * 1000, level, agl, thd, msg)
+ else
+ buf = ('[%s.%03.f][%s]%s[#%d:%s:%s]: %s\n'):format(timestr, ms * 1000, level, agl, thd, trimSrc(source), currentline, msg)
+ end
+ m.size = m.size + #buf
+ if m.size > m.maxSize then
+ m.file:write(buf:sub(1, m.size - m.maxSize))
+ m.file:write('[REACH MAX SIZE]')
+ else
+ m.file:write(buf)
+ end
+ return buf
+end
+
+function m.init(root, path)
+ local lastBuf
+ if m.file then
+ m.file:close()
+ m.file = nil
+ local file = ioOpen(m.path, 'rb')
+ if file then
+ lastBuf = file:read(m.maxSize)
+ file:close()
+ end
+ end
+ m.path = path:string()
+ m.prefixLen = #root:string()
+ m.size = 0
+ if not fs.exists(path:parent_path()) then
+ fs.create_directories(path:parent_path())
+ end
+ if lastBuf then
+ init_log_file()
+ if m.file then
+ m.file:write(lastBuf)
+ m.size = m.size + #lastBuf
+ end
+ end
+end
+
+return m
diff --git a/script/parser/ast.lua b/script/parser/ast.lua
new file mode 100644
index 00000000..d8614eae
--- /dev/null
+++ b/script/parser/ast.lua
@@ -0,0 +1,1751 @@
+local tonumber = tonumber
+local stringChar = string.char
+local utf8Char = utf8.char
+local tableUnpack = table.unpack
+local mathType = math.type
+local tableRemove = table.remove
+local pairs = pairs
+local tableSort = table.sort
+
+_ENV = nil
+
+local State
+local PushError
+local PushDiag
+local PushComment
+
+-- goto 单独处理
+local RESERVED = {
+ ['and'] = true,
+ ['break'] = true,
+ ['do'] = true,
+ ['else'] = true,
+ ['elseif'] = true,
+ ['end'] = true,
+ ['false'] = true,
+ ['for'] = true,
+ ['function'] = true,
+ ['if'] = true,
+ ['in'] = true,
+ ['local'] = true,
+ ['nil'] = true,
+ ['not'] = true,
+ ['or'] = true,
+ ['repeat'] = true,
+ ['return'] = true,
+ ['then'] = true,
+ ['true'] = true,
+ ['until'] = true,
+ ['while'] = true,
+}
+
+local VersionOp = {
+ ['&'] = {'Lua 5.3', 'Lua 5.4'},
+ ['~'] = {'Lua 5.3', 'Lua 5.4'},
+ ['|'] = {'Lua 5.3', 'Lua 5.4'},
+ ['<<'] = {'Lua 5.3', 'Lua 5.4'},
+ ['>>'] = {'Lua 5.3', 'Lua 5.4'},
+ ['//'] = {'Lua 5.3', 'Lua 5.4'},
+}
+
+local function checkOpVersion(op)
+ local versions = VersionOp[op.type]
+ if not versions then
+ return
+ end
+ for i = 1, #versions do
+ if versions[i] == State.version then
+ return
+ end
+ end
+ PushError {
+ type = 'UNSUPPORT_SYMBOL',
+ start = op.start,
+ finish = op.finish,
+ version = versions,
+ info = {
+ version = State.version,
+ }
+ }
+end
+
+local function checkMissEnd(start)
+ if not State.MissEndErr then
+ return
+ end
+ local err = State.MissEndErr
+ State.MissEndErr = nil
+ local _, finish = State.lua:find('[%w_]+', start)
+ if not finish then
+ return
+ end
+ err.info.related = {
+ {
+ start = start,
+ finish = finish,
+ }
+ }
+ PushError {
+ type = 'MISS_END',
+ start = start,
+ finish = finish,
+ }
+end
+
+local function getSelect(vararg, index)
+ return {
+ type = 'select',
+ start = vararg.start,
+ finish = vararg.finish,
+ vararg = vararg,
+ index = index,
+ }
+end
+
+local function getValue(values, i)
+ if not values then
+ return nil, nil
+ end
+ local value = values[i]
+ if not value then
+ local last = values[#values]
+ if not last then
+ return nil, nil
+ end
+ if last.type == 'call' or last.type == 'varargs' then
+ return getSelect(last, i - #values + 1)
+ end
+ return nil, nil
+ end
+ if value.type == 'call' or value.type == 'varargs' then
+ value = getSelect(value, 1)
+ end
+ return value
+end
+
+local function createLocal(key, effect, value, attrs)
+ if not key then
+ return nil
+ end
+ key.type = 'local'
+ key.effect = effect
+ key.value = value
+ key.attrs = attrs
+ if value then
+ key.range = value.finish
+ end
+ return key
+end
+
+local function createCall(args, start, finish)
+ if args then
+ args.type = 'callargs'
+ args.start = start
+ args.finish = finish
+ end
+ return {
+ type = 'call',
+ start = start,
+ finish = finish,
+ args = args,
+ }
+end
+
+local function packList(start, list, finish)
+ local lastFinish = start
+ local wantName = true
+ local count = 0
+ for i = 1, #list do
+ local ast = list[i]
+ if ast.type == ',' then
+ if wantName or i == #list then
+ PushError {
+ type = 'UNEXPECT_SYMBOL',
+ start = ast.start,
+ finish = ast.finish,
+ info = {
+ symbol = ',',
+ }
+ }
+ end
+ wantName = true
+ else
+ if not wantName then
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = lastFinish,
+ finish = ast.start - 1,
+ info = {
+ symbol = ',',
+ }
+ }
+ end
+ wantName = false
+ count = count + 1
+ list[count] = list[i]
+ end
+ lastFinish = ast.finish + 1
+ end
+ for i = count + 1, #list do
+ list[i] = nil
+ end
+ list.type = 'list'
+ list.start = start
+ list.finish = finish - 1
+ return list
+end
+
+local BinaryLevel = {
+ ['or'] = 1,
+ ['and'] = 2,
+ ['<='] = 3,
+ ['>='] = 3,
+ ['<'] = 3,
+ ['>'] = 3,
+ ['~='] = 3,
+ ['=='] = 3,
+ ['|'] = 4,
+ ['~'] = 5,
+ ['&'] = 6,
+ ['<<'] = 7,
+ ['>>'] = 7,
+ ['..'] = 8,
+ ['+'] = 9,
+ ['-'] = 9,
+ ['*'] = 10,
+ ['//'] = 10,
+ ['/'] = 10,
+ ['%'] = 10,
+ ['^'] = 11,
+}
+
+local BinaryForward = {
+ [01] = true,
+ [02] = true,
+ [03] = true,
+ [04] = true,
+ [05] = true,
+ [06] = true,
+ [07] = true,
+ [08] = false,
+ [09] = true,
+ [10] = true,
+ [11] = false,
+}
+
+local Defs = {
+ Nil = function (pos)
+ return {
+ type = 'nil',
+ start = pos,
+ finish = pos + 2,
+ }
+ end,
+ True = function (pos)
+ return {
+ type = 'boolean',
+ start = pos,
+ finish = pos + 3,
+ [1] = true,
+ }
+ end,
+ False = function (pos)
+ return {
+ type = 'boolean',
+ start = pos,
+ finish = pos + 4,
+ [1] = false,
+ }
+ end,
+ ShortComment = function (start, text, finish)
+ PushComment {
+ start = start,
+ finish = finish - 1,
+ text = text,
+ }
+ end,
+ LongComment = function (beforeEq, afterEq, str, missPos)
+ if missPos then
+ local endSymbol = ']' .. ('='):rep(afterEq-beforeEq) .. ']'
+ local s, _, w = str:find('(%][%=]*%])[%c%s]*$')
+ if s then
+ PushError {
+ type = 'ERR_LCOMMENT_END',
+ start = missPos - #str + s - 1,
+ finish = missPos - #str + s + #w - 2,
+ info = {
+ symbol = endSymbol,
+ },
+ fix = {
+ title = 'FIX_LCOMMENT_END',
+ {
+ start = missPos - #str + s - 1,
+ finish = missPos - #str + s + #w - 2,
+ text = endSymbol,
+ }
+ },
+ }
+ end
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = missPos,
+ finish = missPos,
+ info = {
+ symbol = endSymbol,
+ },
+ fix = {
+ title = 'ADD_LCOMMENT_END',
+ {
+ start = missPos,
+ finish = missPos,
+ text = endSymbol,
+ }
+ },
+ }
+ end
+ end,
+ CLongComment = function (start1, finish1, start2, finish2)
+ PushError {
+ type = 'ERR_C_LONG_COMMENT',
+ start = start1,
+ finish = finish2 - 1,
+ fix = {
+ title = 'FIX_C_LONG_COMMENT',
+ {
+ start = start1,
+ finish = finish1 - 1,
+ text = '--[[',
+ },
+ {
+ start = start2,
+ finish = finish2 - 1,
+ text = '--]]'
+ },
+ }
+ }
+ end,
+ CCommentPrefix = function (start, finish)
+ PushError {
+ type = 'ERR_COMMENT_PREFIX',
+ start = start,
+ finish = finish - 1,
+ fix = {
+ title = 'FIX_COMMENT_PREFIX',
+ {
+ start = start,
+ finish = finish - 1,
+ text = '--',
+ },
+ }
+ }
+ end,
+ String = function (start, quote, str, finish)
+ return {
+ type = 'string',
+ start = start,
+ finish = finish - 1,
+ [1] = str,
+ [2] = quote,
+ }
+ end,
+ LongString = function (beforeEq, afterEq, str, missPos)
+ if missPos then
+ local endSymbol = ']' .. ('='):rep(afterEq-beforeEq) .. ']'
+ local s, _, w = str:find('(%][%=]*%])[%c%s]*$')
+ if s then
+ PushError {
+ type = 'ERR_LSTRING_END',
+ start = missPos - #str + s - 1,
+ finish = missPos - #str + s + #w - 2,
+ info = {
+ symbol = endSymbol,
+ },
+ fix = {
+ title = 'FIX_LSTRING_END',
+ {
+ start = missPos - #str + s - 1,
+ finish = missPos - #str + s + #w - 2,
+ text = endSymbol,
+ }
+ },
+ }
+ end
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = missPos,
+ finish = missPos,
+ info = {
+ symbol = endSymbol,
+ },
+ fix = {
+ title = 'ADD_LSTRING_END',
+ {
+ start = missPos,
+ finish = missPos,
+ text = endSymbol,
+ }
+ },
+ }
+ end
+ return '[' .. ('='):rep(afterEq-beforeEq) .. '[', str
+ end,
+ Char10 = function (char)
+ char = tonumber(char)
+ if not char or char < 0 or char > 255 then
+ return ''
+ end
+ return stringChar(char)
+ end,
+ Char16 = function (pos, char)
+ if State.version == 'Lua 5.1' then
+ PushError {
+ type = 'ERR_ESC',
+ start = pos-1,
+ finish = pos,
+ version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
+ info = {
+ version = State.version,
+ }
+ }
+ return char
+ end
+ return stringChar(tonumber(char, 16))
+ end,
+ CharUtf8 = function (pos, char)
+ if State.version ~= 'Lua 5.3'
+ and State.version ~= 'Lua 5.4'
+ and State.version ~= 'LuaJIT'
+ then
+ PushError {
+ type = 'ERR_ESC',
+ start = pos-3,
+ finish = pos-2,
+ version = {'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
+ info = {
+ version = State.version,
+ }
+ }
+ return char
+ end
+ if #char == 0 then
+ PushError {
+ type = 'UTF8_SMALL',
+ start = pos-3,
+ finish = pos,
+ }
+ return ''
+ end
+ local v = tonumber(char, 16)
+ if not v then
+ for i = 1, #char do
+ if not tonumber(char:sub(i, i), 16) then
+ PushError {
+ type = 'MUST_X16',
+ start = pos + i - 1,
+ finish = pos + i - 1,
+ }
+ end
+ end
+ return ''
+ end
+ if State.version == 'Lua 5.4' then
+ if v < 0 or v > 0x7FFFFFFF then
+ PushError {
+ type = 'UTF8_MAX',
+ start = pos-3,
+ finish = pos+#char,
+ info = {
+ min = '00000000',
+ max = '7FFFFFFF',
+ }
+ }
+ end
+ else
+ if v < 0 or v > 0x10FFFF then
+ PushError {
+ type = 'UTF8_MAX',
+ start = pos-3,
+ finish = pos+#char,
+ version = v <= 0x7FFFFFFF and 'Lua 5.4' or nil,
+ info = {
+ min = '000000',
+ max = '10FFFF',
+ }
+ }
+ end
+ end
+ if v >= 0 and v <= 0x10FFFF then
+ return utf8Char(v)
+ end
+ return ''
+ end,
+ Number = function (start, number, finish)
+ local n = tonumber(number)
+ if n then
+ State.LastNumber = {
+ type = 'number',
+ start = start,
+ finish = finish - 1,
+ [1] = n,
+ }
+ return State.LastNumber
+ else
+ PushError {
+ type = 'MALFORMED_NUMBER',
+ start = start,
+ finish = finish - 1,
+ }
+ State.LastNumber = {
+ type = 'number',
+ start = start,
+ finish = finish - 1,
+ [1] = 0,
+ }
+ return State.LastNumber
+ end
+ end,
+ FFINumber = function (start, symbol)
+ local lastNumber = State.LastNumber
+ if mathType(lastNumber[1]) == 'float' then
+ PushError {
+ type = 'UNKNOWN_SYMBOL',
+ start = start,
+ finish = start + #symbol - 1,
+ info = {
+ symbol = symbol,
+ }
+ }
+ lastNumber[1] = 0
+ return
+ end
+ if State.version ~= 'LuaJIT' then
+ PushError {
+ type = 'UNSUPPORT_SYMBOL',
+ start = start,
+ finish = start + #symbol - 1,
+ version = 'LuaJIT',
+ info = {
+ version = State.version,
+ }
+ }
+ lastNumber[1] = 0
+ end
+ end,
+ ImaginaryNumber = function (start, symbol)
+ local lastNumber = State.LastNumber
+ if State.version ~= 'LuaJIT' then
+ PushError {
+ type = 'UNSUPPORT_SYMBOL',
+ start = start,
+ finish = start + #symbol - 1,
+ version = 'LuaJIT',
+ info = {
+ version = State.version,
+ }
+ }
+ end
+ lastNumber[1] = 0
+ end,
+ Name = function (start, str, finish)
+ local isKeyWord
+ if RESERVED[str] then
+ isKeyWord = true
+ elseif str == 'goto' then
+ if State.version ~= 'Lua 5.1' and State.version ~= 'LuaJIT' then
+ isKeyWord = true
+ end
+ end
+ if isKeyWord then
+ PushError {
+ type = 'KEYWORD',
+ start = start,
+ finish = finish - 1,
+ }
+ end
+ return {
+ type = 'name',
+ start = start,
+ finish = finish - 1,
+ [1] = str,
+ }
+ end,
+ GetField = function (dot, field)
+ local obj = {
+ type = 'getfield',
+ field = field,
+ dot = dot,
+ start = dot.start,
+ finish = (field or dot).finish,
+ }
+ if field then
+ field.type = 'field'
+ field.parent = obj
+ end
+ return obj
+ end,
+ GetIndex = function (start, index, finish)
+ local obj = {
+ type = 'getindex',
+ start = start,
+ finish = finish - 1,
+ index = index,
+ }
+ if index then
+ index.parent = obj
+ end
+ return obj
+ end,
+ GetMethod = function (colon, method)
+ local obj = {
+ type = 'getmethod',
+ method = method,
+ colon = colon,
+ start = colon.start,
+ finish = (method or colon).finish,
+ }
+ if method then
+ method.type = 'method'
+ method.parent = obj
+ end
+ return obj
+ end,
+ Single = function (unit)
+ unit.type = 'getname'
+ return unit
+ end,
+ Simple = function (units)
+ local last = units[1]
+ for i = 2, #units do
+ local current = units[i]
+ current.node = last
+ current.start = last.start
+ last.next = current
+ last = units[i]
+ end
+ return last
+ end,
+ SimpleCall = function (call)
+ if call.type ~= 'call' and call.type ~= 'getmethod' then
+ PushError {
+ type = 'EXP_IN_ACTION',
+ start = call.start,
+ finish = call.finish,
+ }
+ end
+ return call
+ end,
+ BinaryOp = function (start, op)
+ return {
+ type = op,
+ start = start,
+ finish = start + #op - 1,
+ }
+ end,
+ UnaryOp = function (start, op)
+ return {
+ type = op,
+ start = start,
+ finish = start + #op - 1,
+ }
+ end,
+ Unary = function (first, ...)
+ if not ... then
+ return nil
+ end
+ local list = {first, ...}
+ local e = list[#list]
+ for i = #list - 1, 1, -1 do
+ local op = list[i]
+ checkOpVersion(op)
+ e = {
+ type = 'unary',
+ op = op,
+ start = op.start,
+ finish = e.finish,
+ [1] = e,
+ }
+ end
+ return e
+ end,
+ SubBinary = function (op, symb)
+ if symb then
+ return op, symb
+ end
+ PushError {
+ type = 'MISS_EXP',
+ start = op.start,
+ finish = op.finish,
+ }
+ end,
+ Binary = function (first, op, second, ...)
+ if not first then
+ return second
+ end
+ if not op then
+ return first
+ end
+ if not ... then
+ checkOpVersion(op)
+ return {
+ type = 'binary',
+ op = op,
+ start = first.start,
+ finish = second.finish,
+ [1] = first,
+ [2] = second,
+ }
+ end
+ local list = {first, op, second, ...}
+ local ops = {}
+ for i = 2, #list, 2 do
+ ops[#ops+1] = i
+ end
+ tableSort(ops, function (a, b)
+ local op1 = list[a]
+ local op2 = list[b]
+ local lv1 = BinaryLevel[op1.type]
+ local lv2 = BinaryLevel[op2.type]
+ if lv1 == lv2 then
+ local forward = BinaryForward[lv1]
+ if forward then
+ return op1.start > op2.start
+ else
+ return op1.start < op2.start
+ end
+ else
+ return lv1 < lv2
+ end
+ end)
+ local final
+ for i = #ops, 1, -1 do
+ local n = ops[i]
+ local op = list[n]
+ local left = list[n-1]
+ local right = list[n+1]
+ local exp = {
+ type = 'binary',
+ op = op,
+ start = left.start,
+ finish = right and right.finish or op.finish,
+ [1] = left,
+ [2] = right,
+ }
+ local leftIndex, rightIndex
+ if list[left] then
+ leftIndex = list[left[1]]
+ else
+ leftIndex = n - 1
+ end
+ if list[right] then
+ rightIndex = list[right[2]]
+ else
+ rightIndex = n + 1
+ end
+
+ list[leftIndex] = exp
+ list[rightIndex] = exp
+ list[left] = leftIndex
+ list[right] = rightIndex
+ list[exp] = n
+ final = exp
+
+ checkOpVersion(op)
+ end
+ return final
+ end,
+ Paren = function (start, exp, finish)
+ if exp and exp.type == 'paren' then
+ exp.start = start
+ exp.finish = finish - 1
+ return exp
+ end
+ return {
+ type = 'paren',
+ start = start,
+ finish = finish - 1,
+ exp = exp
+ }
+ end,
+ VarArgs = function (dots)
+ dots.type = 'varargs'
+ return dots
+ end,
+ PackLoopArgs = function (start, list, finish)
+ local list = packList(start, list, finish)
+ if #list == 0 then
+ PushError {
+ type = 'MISS_LOOP_MIN',
+ start = finish,
+ finish = finish,
+ }
+ elseif #list == 1 then
+ PushError {
+ type = 'MISS_LOOP_MAX',
+ start = finish,
+ finish = finish,
+ }
+ end
+ return list
+ end,
+ PackInNameList = function (start, list, finish)
+ local list = packList(start, list, finish)
+ if #list == 0 then
+ PushError {
+ type = 'MISS_NAME',
+ start = start,
+ finish = finish,
+ }
+ end
+ return list
+ end,
+ PackInExpList = function (start, list, finish)
+ local list = packList(start, list, finish)
+ if #list == 0 then
+ PushError {
+ type = 'MISS_EXP',
+ start = start,
+ finish = finish,
+ }
+ end
+ return list
+ end,
+ PackExpList = function (start, list, finish)
+ local list = packList(start, list, finish)
+ return list
+ end,
+ PackNameList = function (start, list, finish)
+ local list = packList(start, list, finish)
+ return list
+ end,
+ Call = function (start, args, finish)
+ return createCall(args, start, finish-1)
+ end,
+ COMMA = function (start)
+ return {
+ type = ',',
+ start = start,
+ finish = start,
+ }
+ end,
+ SEMICOLON = function (start)
+ return {
+ type = ';',
+ start = start,
+ finish = start,
+ }
+ end,
+ DOTS = function (start)
+ return {
+ type = '...',
+ start = start,
+ finish = start + 2,
+ }
+ end,
+ COLON = function (start)
+ return {
+ type = ':',
+ start = start,
+ finish = start,
+ }
+ end,
+ DOT = function (start)
+ return {
+ type = '.',
+ start = start,
+ finish = start,
+ }
+ end,
+ Function = function (functionStart, functionFinish, args, actions, endStart, endFinish)
+ actions.type = 'function'
+ actions.start = functionStart
+ actions.finish = endFinish - 1
+ actions.args = args
+ actions.keyword= {
+ functionStart, functionFinish - 1,
+ endStart, endFinish - 1,
+ }
+ checkMissEnd(functionStart)
+ return actions
+ end,
+ NamedFunction = function (functionStart, functionFinish, name, args, actions, endStart, endFinish)
+ actions.type = 'function'
+ actions.start = functionStart
+ actions.finish = endFinish - 1
+ actions.args = args
+ actions.keyword= {
+ functionStart, functionFinish - 1,
+ endStart, endFinish - 1,
+ }
+ checkMissEnd(functionStart)
+ if not name then
+ return
+ end
+ if name.type == 'getname' then
+ name.type = 'setname'
+ name.value = actions
+ elseif name.type == 'getfield' then
+ name.type = 'setfield'
+ name.value = actions
+ elseif name.type == 'getmethod' then
+ name.type = 'setmethod'
+ name.value = actions
+ end
+ name.range = actions.finish
+ name.vstart = functionStart
+ return name
+ end,
+ LocalFunction = function (start, functionStart, functionFinish, name, args, actions, endStart, endFinish)
+ actions.type = 'function'
+ actions.start = start
+ actions.finish = endFinish - 1
+ actions.args = args
+ actions.keyword= {
+ functionStart, functionFinish - 1,
+ endStart, endFinish - 1,
+ }
+ checkMissEnd(start)
+
+ if not name then
+ return
+ end
+
+ if name.type ~= 'getname' then
+ PushError {
+ type = 'UNEXPECT_LFUNC_NAME',
+ start = name.start,
+ finish = name.finish,
+ }
+ return
+ end
+
+ local loc = createLocal(name, name.start, actions)
+ loc.localfunction = true
+ loc.vstart = functionStart
+
+ return loc
+ end,
+ Table = function (start, tbl, finish)
+ tbl.type = 'table'
+ tbl.start = start
+ tbl.finish = finish - 1
+ local wantField = true
+ local lastStart = start + 1
+ local fieldCount = 0
+ for i = 1, #tbl do
+ local field = tbl[i]
+ if field.type == ',' or field.type == ';' then
+ if wantField then
+ PushError {
+ type = 'MISS_EXP',
+ start = lastStart,
+ finish = field.start - 1,
+ }
+ end
+ wantField = true
+ lastStart = field.finish + 1
+ else
+ if not wantField then
+ PushError {
+ type = 'MISS_SEP_IN_TABLE',
+ start = lastStart,
+ finish = field.start - 1,
+ }
+ end
+ wantField = false
+ lastStart = field.finish + 1
+ fieldCount = fieldCount + 1
+ tbl[fieldCount] = field
+ end
+ end
+ for i = fieldCount + 1, #tbl do
+ tbl[i] = nil
+ end
+ return tbl
+ end,
+ NewField = function (start, field, value, finish)
+ local obj = {
+ type = 'tablefield',
+ start = start,
+ finish = finish-1,
+ field = field,
+ value = value,
+ }
+ if field then
+ field.type = 'field'
+ field.parent = obj
+ end
+ return obj
+ end,
+ NewIndex = function (start, index, value, finish)
+ local obj = {
+ type = 'tableindex',
+ start = start,
+ finish = finish-1,
+ index = index,
+ value = value,
+ }
+ if index then
+ index.parent = obj
+ end
+ return obj
+ end,
+ FuncArgs = function (start, args, finish)
+ args.type = 'funcargs'
+ args.start = start
+ args.finish = finish - 1
+ local lastStart = start + 1
+ local wantName = true
+ local argCount = 0
+ for i = 1, #args do
+ local arg = args[i]
+ local argAst = arg
+ if argAst.type == ',' then
+ if wantName then
+ PushError {
+ type = 'MISS_NAME',
+ start = lastStart,
+ finish = argAst.start-1,
+ }
+ end
+ wantName = true
+ else
+ if not wantName then
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = lastStart-1,
+ finish = argAst.start-1,
+ info = {
+ symbol = ',',
+ }
+ }
+ end
+ wantName = false
+ argCount = argCount + 1
+
+ if argAst.type == '...' then
+ args[argCount] = arg
+ if i < #args then
+ local a = args[i+1]
+ local b = args[#args]
+ PushError {
+ type = 'ARGS_AFTER_DOTS',
+ start = a.start,
+ finish = b.finish,
+ }
+ end
+ break
+ else
+ args[argCount] = createLocal(arg, arg.start)
+ end
+ end
+ lastStart = argAst.finish + 1
+ end
+ for i = argCount + 1, #args do
+ args[i] = nil
+ end
+ if wantName and argCount > 0 then
+ PushError {
+ type = 'MISS_NAME',
+ start = lastStart,
+ finish = finish - 1,
+ }
+ end
+ return args
+ end,
+ Set = function (start, keys, values, finish)
+ for i = 1, #keys do
+ local key = keys[i]
+ if key.type == 'getname' then
+ key.type = 'setname'
+ key.value = getValue(values, i)
+ elseif key.type == 'getfield' then
+ key.type = 'setfield'
+ key.value = getValue(values, i)
+ elseif key.type == 'getindex' then
+ key.type = 'setindex'
+ key.value = getValue(values, i)
+ end
+ if key.value then
+ key.range = key.value.finish
+ end
+ end
+ if values then
+ for i = #keys+1, #values do
+ local value = values[i]
+ PushDiag('redundant-value', {
+ start = value.start,
+ finish = value.finish,
+ max = #keys,
+ passed = #values,
+ })
+ end
+ end
+ return tableUnpack(keys)
+ end,
+ LocalAttr = function (attrs)
+ if #attrs == 0 then
+ return nil
+ end
+ for i = 1, #attrs do
+ local attr = attrs[i]
+ local attrAst = attr
+ attrAst.type = 'localattr'
+ if State.version ~= 'Lua 5.4' then
+ PushError {
+ type = 'UNSUPPORT_SYMBOL',
+ start = attrAst.start,
+ finish = attrAst.finish,
+ version = 'Lua 5.4',
+ info = {
+ version = State.version,
+ }
+ }
+ elseif attrAst[1] ~= 'const' and attrAst[1] ~= 'close' then
+ PushError {
+ type = 'UNKNOWN_TAG',
+ start = attrAst.start,
+ finish = attrAst.finish,
+ info = {
+ tag = attrAst[1],
+ }
+ }
+ elseif i > 1 then
+ PushError {
+ type = 'MULTI_TAG',
+ start = attrAst.start,
+ finish = attrAst.finish,
+ info = {
+ tag = attrAst[1],
+ }
+ }
+ end
+ end
+ attrs.start = attrs[1].start
+ attrs.finish = attrs[#attrs].finish
+ return attrs
+ end,
+ LocalName = function (name, attrs)
+ if not name then
+ return name
+ end
+ name.attrs = attrs
+ return name
+ end,
+ Local = function (start, keys, values, finish)
+ for i = 1, #keys do
+ local key = keys[i]
+ local attrs = key.attrs
+ key.attrs = nil
+ local value = getValue(values, i)
+ createLocal(key, finish, value, attrs)
+ end
+ if values then
+ for i = #keys+1, #values do
+ local value = values[i]
+ PushDiag('redundant-value', {
+ start = value.start,
+ finish = value.finish,
+ max = #keys,
+ passed = #values,
+ })
+ end
+ end
+ return tableUnpack(keys)
+ end,
+ Do = function (start, actions, endA, endB)
+ actions.type = 'do'
+ actions.start = start
+ actions.finish = endB - 1
+ actions.keyword= {
+ start, start + #'do' - 1,
+ endA , endB - 1,
+ }
+ checkMissEnd(start)
+ return actions
+ end,
+ Break = function (start, finish)
+ return {
+ type = 'break',
+ start = start,
+ finish = finish - 1,
+ }
+ end,
+ Return = function (start, exps, finish)
+ exps.type = 'return'
+ exps.start = start
+ exps.finish = finish - 1
+ return exps
+ end,
+ Label = function (start, name, finish)
+ if State.version == 'Lua 5.1' then
+ PushError {
+ type = 'UNSUPPORT_SYMBOL',
+ start = start,
+ finish = finish - 1,
+ version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
+ info = {
+ version = State.version,
+ }
+ }
+ return
+ end
+ if not name then
+ return nil
+ end
+ name.type = 'label'
+ return name
+ end,
+ GoTo = function (start, name, finish)
+ if State.version == 'Lua 5.1' then
+ PushError {
+ type = 'UNSUPPORT_SYMBOL',
+ start = start,
+ finish = finish - 1,
+ version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
+ info = {
+ version = State.version,
+ }
+ }
+ return
+ end
+ if not name then
+ return nil
+ end
+ name.type = 'goto'
+ return name
+ end,
+ IfBlock = function (ifStart, ifFinish, exp, thenStart, thenFinish, actions, finish)
+ actions.type = 'ifblock'
+ actions.start = ifStart
+ actions.finish = finish - 1
+ actions.filter = exp
+ actions.keyword= {
+ ifStart, ifFinish - 1,
+ thenStart, thenFinish - 1,
+ }
+ return actions
+ end,
+ ElseIfBlock = function (elseifStart, elseifFinish, exp, thenStart, thenFinish, actions, finish)
+ actions.type = 'elseifblock'
+ actions.start = elseifStart
+ actions.finish = finish - 1
+ actions.filter = exp
+ actions.keyword= {
+ elseifStart, elseifFinish - 1,
+ thenStart, thenFinish - 1,
+ }
+ return actions
+ end,
+ ElseBlock = function (elseStart, elseFinish, actions, finish)
+ actions.type = 'elseblock'
+ actions.start = elseStart
+ actions.finish = finish - 1
+ actions.keyword= {
+ elseStart, elseFinish - 1,
+ }
+ return actions
+ end,
+ If = function (start, blocks, endStart, endFinish)
+ blocks.type = 'if'
+ blocks.start = start
+ blocks.finish = endFinish - 1
+ local hasElse
+ for i = 1, #blocks do
+ local block = blocks[i]
+ if i == 1 and block.type ~= 'ifblock' then
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = block.start,
+ finish = block.start,
+ info = {
+ symbol = 'if',
+ }
+ }
+ end
+ if hasElse then
+ PushError {
+ type = 'BLOCK_AFTER_ELSE',
+ start = block.start,
+ finish = block.finish,
+ }
+ end
+ if block.type == 'elseblock' then
+ hasElse = true
+ end
+ end
+ checkMissEnd(start)
+ return blocks
+ end,
+ Loop = function (forA, forB, arg, steps, doA, doB, blockStart, block, endA, endB)
+ local loc = createLocal(arg, blockStart, steps[1])
+ block.type = 'loop'
+ block.start = forA
+ block.finish = endB - 1
+ block.loc = loc
+ block.max = steps[2]
+ block.step = steps[3]
+ block.keyword= {
+ forA, forB - 1,
+ doA , doB - 1,
+ endA, endB - 1,
+ }
+ checkMissEnd(forA)
+ return block
+ end,
+ In = function (forA, forB, keys, inA, inB, exp, doA, doB, blockStart, block, endA, endB)
+ local func = tableRemove(exp, 1)
+ block.type = 'in'
+ block.start = forA
+ block.finish = endB - 1
+ block.keys = keys
+ block.keyword= {
+ forA, forB - 1,
+ inA , inB - 1,
+ doA , doB - 1,
+ endA, endB - 1,
+ }
+
+ local values
+ if func then
+ local call = createCall(exp, func.finish + 1, exp.finish)
+ call.node = func
+ call.start = func.start
+ func.next = call
+ values = { call }
+ keys.range = call.finish
+ end
+ for i = 1, #keys do
+ local loc = keys[i]
+ if values then
+ createLocal(loc, blockStart, getValue(values, i))
+ else
+ createLocal(loc, blockStart)
+ end
+ end
+ checkMissEnd(forA)
+ return block
+ end,
+ While = function (whileA, whileB, filter, doA, doB, block, endA, endB)
+ block.type = 'while'
+ block.start = whileA
+ block.finish = endB - 1
+ block.filter = filter
+ block.keyword= {
+ whileA, whileB - 1,
+ doA , doB - 1,
+ endA , endB - 1,
+ }
+ checkMissEnd(whileA)
+ return block
+ end,
+ Repeat = function (repeatA, repeatB, block, untilA, untilB, filter, finish)
+ block.type = 'repeat'
+ block.start = repeatA
+ block.finish = finish
+ block.filter = filter
+ block.keyword= {
+ repeatA, repeatB - 1,
+ untilA , untilB - 1,
+ }
+ return block
+ end,
+ Lua = function (start, actions, finish)
+ actions.type = 'main'
+ actions.start = start
+ actions.finish = finish - 1
+ return actions
+ end,
+
+ -- 捕获错误
+ UnknownSymbol = function (start, symbol)
+ PushError {
+ type = 'UNKNOWN_SYMBOL',
+ start = start,
+ finish = start + #symbol - 1,
+ info = {
+ symbol = symbol,
+ }
+ }
+ return
+ end,
+ UnknownAction = function (start, symbol)
+ PushError {
+ type = 'UNKNOWN_SYMBOL',
+ start = start,
+ finish = start + #symbol - 1,
+ info = {
+ symbol = symbol,
+ }
+ }
+ end,
+ DirtyName = function (pos)
+ PushError {
+ type = 'MISS_NAME',
+ start = pos,
+ finish = pos,
+ }
+ return nil
+ end,
+ DirtyExp = function (pos)
+ PushError {
+ type = 'MISS_EXP',
+ start = pos,
+ finish = pos,
+ }
+ return nil
+ end,
+ MissExp = function (pos)
+ PushError {
+ type = 'MISS_EXP',
+ start = pos,
+ finish = pos,
+ }
+ end,
+ MissExponent = function (start, finish)
+ PushError {
+ type = 'MISS_EXPONENT',
+ start = start,
+ finish = finish - 1,
+ }
+ end,
+ MissQuote1 = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = '"'
+ }
+ }
+ end,
+ MissQuote2 = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = "'"
+ }
+ }
+ end,
+ MissEscX = function (pos)
+ PushError {
+ type = 'MISS_ESC_X',
+ start = pos-2,
+ finish = pos+1,
+ }
+ end,
+ MissTL = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = '{',
+ }
+ }
+ end,
+ MissTR = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = '}',
+ }
+ }
+ end,
+ MissBR = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = ']',
+ }
+ }
+ end,
+ MissPL = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = '(',
+ }
+ }
+ end,
+ MissPR = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = ')',
+ }
+ }
+ end,
+ ErrEsc = function (pos)
+ PushError {
+ type = 'ERR_ESC',
+ start = pos-1,
+ finish = pos,
+ }
+ end,
+ MustX16 = function (pos, str)
+ PushError {
+ type = 'MUST_X16',
+ start = pos,
+ finish = pos + #str - 1,
+ }
+ end,
+ MissAssign = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = '=',
+ }
+ }
+ end,
+ MissTableSep = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = ','
+ }
+ }
+ end,
+ MissField = function (pos)
+ PushError {
+ type = 'MISS_FIELD',
+ start = pos,
+ finish = pos,
+ }
+ end,
+ MissMethod = function (pos)
+ PushError {
+ type = 'MISS_METHOD',
+ start = pos,
+ finish = pos,
+ }
+ end,
+ MissLabel = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = '::',
+ }
+ }
+ end,
+ MissEnd = function (pos)
+ State.MissEndErr = PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = 'end',
+ }
+ }
+ return pos, pos
+ end,
+ MissDo = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = 'do',
+ }
+ }
+ return pos, pos
+ end,
+ MissComma = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = ',',
+ }
+ }
+ end,
+ MissIn = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = 'in',
+ }
+ }
+ return pos, pos
+ end,
+ MissUntil = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = 'until',
+ }
+ }
+ return pos, pos
+ end,
+ MissThen = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = 'then',
+ }
+ }
+ return pos, pos
+ end,
+ MissName = function (pos)
+ PushError {
+ type = 'MISS_NAME',
+ start = pos,
+ finish = pos,
+ }
+ end,
+ ExpInAction = function (start, exp, finish)
+ PushError {
+ type = 'EXP_IN_ACTION',
+ start = start,
+ finish = finish - 1,
+ }
+ -- 当exp为nil时,不能返回任何值,否则会产生带洞的actionlist
+ if exp then
+ return exp
+ else
+ return
+ end
+ end,
+ MissIf = function (start, block)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = start,
+ finish = start,
+ info = {
+ symbol = 'if',
+ }
+ }
+ return block
+ end,
+ MissGT = function (start)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = start,
+ finish = start,
+ info = {
+ symbol = '>'
+ }
+ }
+ end,
+ ErrAssign = function (start, finish)
+ PushError {
+ type = 'ERR_ASSIGN_AS_EQ',
+ start = start,
+ finish = finish - 1,
+ fix = {
+ title = 'FIX_ASSIGN_AS_EQ',
+ {
+ start = start,
+ finish = finish - 1,
+ text = '=',
+ }
+ }
+ }
+ end,
+ ErrEQ = function (start, finish)
+ PushError {
+ type = 'ERR_EQ_AS_ASSIGN',
+ start = start,
+ finish = finish - 1,
+ fix = {
+ title = 'FIX_EQ_AS_ASSIGN',
+ {
+ start = start,
+ finish = finish - 1,
+ text = '==',
+ }
+ }
+ }
+ return '=='
+ end,
+ ErrUEQ = function (start, finish)
+ PushError {
+ type = 'ERR_UEQ',
+ start = start,
+ finish = finish - 1,
+ fix = {
+ title = 'FIX_UEQ',
+ {
+ start = start,
+ finish = finish - 1,
+ text = '~=',
+ }
+ }
+ }
+ return '=='
+ end,
+ ErrThen = function (start, finish)
+ PushError {
+ type = 'ERR_THEN_AS_DO',
+ start = start,
+ finish = finish - 1,
+ fix = {
+ title = 'FIX_THEN_AS_DO',
+ {
+ start = start,
+ finish = finish - 1,
+ text = 'then',
+ }
+ }
+ }
+ return start, finish
+ end,
+ ErrDo = function (start, finish)
+ PushError {
+ type = 'ERR_DO_AS_THEN',
+ start = start,
+ finish = finish - 1,
+ fix = {
+ title = 'FIX_DO_AS_THEN',
+ {
+ start = start,
+ finish = finish - 1,
+ text = 'do',
+ }
+ }
+ }
+ return start, finish
+ end,
+}
+
+local function init(state)
+ State = state
+ PushError = state.pushError
+ PushDiag = state.pushDiag
+ PushComment = state.pushComment
+end
+
+local function close()
+ State = nil
+ PushError = function (...) end
+ PushDiag = function (...) end
+ PushComment = function (...) end
+end
+
+return {
+ defs = Defs,
+ init = init,
+ close = close,
+}
diff --git a/script/parser/calcline.lua b/script/parser/calcline.lua
new file mode 100644
index 00000000..2e944167
--- /dev/null
+++ b/script/parser/calcline.lua
@@ -0,0 +1,94 @@
+local m = require 'lpeglabel'
+local util = require 'utility'
+
+local row
+local fl
+local NL = (m.P'\r\n' + m.S'\r\n') * m.Cp() / function (pos)
+ row = row + 1
+ fl = pos
+end
+local ROWCOL = (NL + m.P(1))^0
+local function rowcol(str, n)
+ row = 1
+ fl = 1
+ ROWCOL:match(str:sub(1, n))
+ local col = n - fl + 1
+ return row, col
+end
+
+local function rowcol_utf8(str, n)
+ row = 1
+ fl = 1
+ ROWCOL:match(str:sub(1, n))
+ return row, util.utf8Len(str, fl, n)
+end
+
+local function position(str, _row, _col)
+ local cur = 1
+ local row = 1
+ while true do
+ if row == _row then
+ return cur + _col - 1
+ elseif row > _row then
+ return cur - 1
+ end
+ local pos = str:find('[\r\n]', cur)
+ if not pos then
+ return #str
+ end
+ row = row + 1
+ if str:sub(pos, pos+1) == '\r\n' then
+ cur = pos + 2
+ else
+ cur = pos + 1
+ end
+ end
+end
+
+local function position_utf8(str, _row, _col)
+ local cur = 1
+ local row = 1
+ while true do
+ if row == _row then
+ return utf8.offset(str, _col, cur)
+ elseif row > _row then
+ return cur - 1
+ end
+ local pos = str:find('[\r\n]', cur)
+ if not pos then
+ return #str
+ end
+ row = row + 1
+ if str:sub(pos, pos+1) == '\r\n' then
+ cur = pos + 2
+ else
+ cur = pos + 1
+ end
+ end
+end
+
+local NL = m.P'\r\n' + m.S'\r\n'
+
+local function line(str, row)
+ local count = 0
+ local res
+ local LINE = m.Cmt((1 - NL)^0, function (_, _, c)
+ count = count + 1
+ if count == row then
+ res = c
+ return false
+ end
+ return true
+ end)
+ local MATCH = (LINE * NL)^0 * LINE
+ MATCH:match(str)
+ return res
+end
+
+return {
+ rowcol = rowcol,
+ rowcol_utf8 = rowcol_utf8,
+ position = position,
+ position_utf8 = position_utf8,
+ line = line,
+}
diff --git a/script/parser/compile.lua b/script/parser/compile.lua
new file mode 100644
index 00000000..2c7172e8
--- /dev/null
+++ b/script/parser/compile.lua
@@ -0,0 +1,561 @@
+local guide = require 'parser.guide'
+local type = type
+
+local specials = {
+ ['_G'] = true,
+ ['rawset'] = true,
+ ['rawget'] = true,
+ ['setmetatable'] = true,
+ ['require'] = true,
+ ['dofile'] = true,
+ ['loadfile'] = true,
+ ['pcall'] = true,
+ ['xpcall'] = true,
+}
+
+_ENV = nil
+
+local LocalLimit = 200
+local pushError, Compile, CompileBlock, Block, GoToTag, ENVMode, Compiled, LocalCount, Version, Root, Options
+
+local function addRef(node, obj)
+ if not node.ref then
+ node.ref = {}
+ end
+ node.ref[#node.ref+1] = obj
+ obj.node = node
+end
+
+local function addSpecial(name, obj)
+ if not Root.specials then
+ Root.specials = {}
+ end
+ if not Root.specials[name] then
+ Root.specials[name] = {}
+ end
+ Root.specials[name][#Root.specials[name]+1] = obj
+ obj.special = name
+end
+
+local vmMap = {
+ ['getname'] = function (obj)
+ local loc = guide.getLocal(obj, obj[1], obj.start)
+ if loc then
+ obj.type = 'getlocal'
+ obj.loc = loc
+ addRef(loc, obj)
+ if loc.special then
+ addSpecial(loc.special, obj)
+ end
+ else
+ obj.type = 'getglobal'
+ local node = guide.getLocal(obj, ENVMode, obj.start)
+ if node then
+ addRef(node, obj)
+ end
+ local name = obj[1]
+ if specials[name] then
+ addSpecial(name, obj)
+ elseif Options and Options.special then
+ local asName = Options.special[name]
+ if specials[asName] then
+ addSpecial(asName, obj)
+ end
+ end
+ end
+ return obj
+ end,
+ ['getfield'] = function (obj)
+ Compile(obj.node, obj)
+ end,
+ ['call'] = function (obj)
+ Compile(obj.node, obj)
+ Compile(obj.args, obj)
+ end,
+ ['callargs'] = function (obj)
+ for i = 1, #obj do
+ Compile(obj[i], obj)
+ end
+ end,
+ ['binary'] = function (obj)
+ Compile(obj[1], obj)
+ Compile(obj[2], obj)
+ end,
+ ['unary'] = function (obj)
+ Compile(obj[1], obj)
+ end,
+ ['varargs'] = function (obj)
+ local func = guide.getParentFunction(obj)
+ if func then
+ local index, vararg = guide.getFunctionVarArgs(func)
+ if not index then
+ pushError {
+ type = 'UNEXPECT_DOTS',
+ start = obj.start,
+ finish = obj.finish,
+ }
+ end
+ if vararg then
+ if not vararg.ref then
+ vararg.ref = {}
+ end
+ vararg.ref[#vararg.ref+1] = obj
+ end
+ end
+ end,
+ ['paren'] = function (obj)
+ Compile(obj.exp, obj)
+ end,
+ ['getindex'] = function (obj)
+ Compile(obj.node, obj)
+ Compile(obj.index, obj)
+ end,
+ ['setindex'] = function (obj)
+ Compile(obj.node, obj)
+ Compile(obj.index, obj)
+ Compile(obj.value, obj)
+ end,
+ ['getmethod'] = function (obj)
+ Compile(obj.node, obj)
+ Compile(obj.method, obj)
+ end,
+ ['setmethod'] = function (obj)
+ Compile(obj.node, obj)
+ Compile(obj.method, obj)
+ local value = obj.value
+ value.localself = {
+ type = 'local',
+ start = 0,
+ finish = 0,
+ method = obj,
+ effect = obj.finish,
+ tag = 'self',
+ [1] = 'self',
+ }
+ Compile(value, obj)
+ end,
+ ['function'] = function (obj)
+ local lastBlock = Block
+ local LastLocalCount = LocalCount
+ Block = obj
+ LocalCount = 0
+ if obj.localself then
+ Compile(obj.localself, obj)
+ obj.localself = nil
+ end
+ Compile(obj.args, obj)
+ for i = 1, #obj do
+ Compile(obj[i], obj)
+ end
+ Block = lastBlock
+ LocalCount = LastLocalCount
+ end,
+ ['funcargs'] = function (obj)
+ for i = 1, #obj do
+ Compile(obj[i], obj)
+ end
+ end,
+ ['table'] = function (obj)
+ for i = 1, #obj do
+ Compile(obj[i], obj)
+ end
+ end,
+ ['tablefield'] = function (obj)
+ Compile(obj.value, obj)
+ end,
+ ['tableindex'] = function (obj)
+ Compile(obj.index, obj)
+ Compile(obj.value, obj)
+ end,
+ ['index'] = function (obj)
+ Compile(obj.index, obj)
+ end,
+ ['select'] = function (obj)
+ local vararg = obj.vararg
+ if vararg.parent then
+ if not vararg.extParent then
+ vararg.extParent = {}
+ end
+ vararg.extParent[#vararg.extParent+1] = obj
+ else
+ Compile(vararg, obj)
+ end
+ end,
+ ['setname'] = function (obj)
+ Compile(obj.value, obj)
+ local loc = guide.getLocal(obj, obj[1], obj.start)
+ if loc then
+ obj.type = 'setlocal'
+ obj.loc = loc
+ addRef(loc, obj)
+ if loc.attrs then
+ local const
+ for i = 1, #loc.attrs do
+ local attr = loc.attrs[i][1]
+ if attr == 'const'
+ or attr == 'close' then
+ const = true
+ break
+ end
+ end
+ if const then
+ pushError {
+ type = 'SET_CONST',
+ start = obj.start,
+ finish = obj.finish,
+ }
+ end
+ end
+ else
+ obj.type = 'setglobal'
+ local node = guide.getLocal(obj, ENVMode, obj.start)
+ if node then
+ addRef(node, obj)
+ end
+ local name = obj[1]
+ if specials[name] then
+ addSpecial(name, obj)
+ elseif Options and Options.special then
+ local asName = Options.special[name]
+ if specials[asName] then
+ addSpecial(asName, obj)
+ end
+ end
+ end
+ end,
+ ['local'] = function (obj)
+ local attrs = obj.attrs
+ if attrs then
+ for i = 1, #attrs do
+ Compile(attrs[i], obj)
+ end
+ end
+ if Block then
+ if not Block.locals then
+ Block.locals = {}
+ end
+ Block.locals[#Block.locals+1] = obj
+ LocalCount = LocalCount + 1
+ if LocalCount > LocalLimit then
+ pushError {
+ type = 'LOCAL_LIMIT',
+ start = obj.start,
+ finish = obj.finish,
+ }
+ end
+ end
+ if obj.localfunction then
+ obj.localfunction = nil
+ end
+ Compile(obj.value, obj)
+ if obj.value and obj.value.special then
+ addSpecial(obj.value.special, obj)
+ end
+ end,
+ ['setfield'] = function (obj)
+ Compile(obj.node, obj)
+ Compile(obj.value, obj)
+ end,
+ ['do'] = function (obj)
+ local lastBlock = Block
+ Block = obj
+ CompileBlock(obj, obj)
+ if Block.locals then
+ LocalCount = LocalCount - #Block.locals
+ end
+ Block = lastBlock
+ end,
+ ['return'] = function (obj)
+ for i = 1, #obj do
+ Compile(obj[i], obj)
+ end
+ if Block and Block[#Block] ~= obj then
+ pushError {
+ type = 'ACTION_AFTER_RETURN',
+ start = obj.start,
+ finish = obj.finish,
+ }
+ end
+ local func = guide.getParentFunction(obj)
+ if func then
+ if not func.returns then
+ func.returns = {}
+ end
+ func.returns[#func.returns+1] = obj
+ end
+ end,
+ ['label'] = function (obj)
+ local block = guide.getBlock(obj)
+ if block then
+ if not block.labels then
+ block.labels = {}
+ end
+ local name = obj[1]
+ local label = guide.getLabel(block, name)
+ if label then
+ if Version == 'Lua 5.4'
+ or block == guide.getBlock(label) then
+ pushError {
+ type = 'REDEFINED_LABEL',
+ start = obj.start,
+ finish = obj.finish,
+ relative = {
+ {
+ label.start,
+ label.finish,
+ }
+ }
+ }
+ end
+ end
+ block.labels[name] = obj
+ end
+ end,
+ ['goto'] = function (obj)
+ GoToTag[#GoToTag+1] = obj
+ end,
+ ['if'] = function (obj)
+ for i = 1, #obj do
+ Compile(obj[i], obj)
+ end
+ end,
+ ['ifblock'] = function (obj)
+ local lastBlock = Block
+ Block = obj
+ Compile(obj.filter, obj)
+ CompileBlock(obj, obj)
+ if Block.locals then
+ LocalCount = LocalCount - #Block.locals
+ end
+ Block = lastBlock
+ end,
+ ['elseifblock'] = function (obj)
+ local lastBlock = Block
+ Block = obj
+ Compile(obj.filter, obj)
+ CompileBlock(obj, obj)
+ if Block.locals then
+ LocalCount = LocalCount - #Block.locals
+ end
+ Block = lastBlock
+ end,
+ ['elseblock'] = function (obj)
+ local lastBlock = Block
+ Block = obj
+ CompileBlock(obj, obj)
+ if Block.locals then
+ LocalCount = LocalCount - #Block.locals
+ end
+ Block = lastBlock
+ end,
+ ['loop'] = function (obj)
+ local lastBlock = Block
+ Block = obj
+ Compile(obj.loc, obj)
+ Compile(obj.max, obj)
+ Compile(obj.step, obj)
+ CompileBlock(obj, obj)
+ if Block.locals then
+ LocalCount = LocalCount - #Block.locals
+ end
+ Block = lastBlock
+ end,
+ ['in'] = function (obj)
+ local lastBlock = Block
+ Block = obj
+ local keys = obj.keys
+ for i = 1, #keys do
+ Compile(keys[i], obj)
+ end
+ CompileBlock(obj, obj)
+ if Block.locals then
+ LocalCount = LocalCount - #Block.locals
+ end
+ Block = lastBlock
+ end,
+ ['while'] = function (obj)
+ local lastBlock = Block
+ Block = obj
+ Compile(obj.filter, obj)
+ CompileBlock(obj, obj)
+ if Block.locals then
+ LocalCount = LocalCount - #Block.locals
+ end
+ Block = lastBlock
+ end,
+ ['repeat'] = function (obj)
+ local lastBlock = Block
+ Block = obj
+ CompileBlock(obj, obj)
+ Compile(obj.filter, obj)
+ if Block.locals then
+ LocalCount = LocalCount - #Block.locals
+ end
+ Block = lastBlock
+ end,
+ ['break'] = function (obj)
+ local block = guide.getBreakBlock(obj)
+ if block then
+ if not block.breaks then
+ block.breaks = {}
+ end
+ block.breaks[#block.breaks+1] = obj
+ else
+ pushError {
+ type = 'BREAK_OUTSIDE',
+ start = obj.start,
+ finish = obj.finish,
+ }
+ end
+ end,
+ ['main'] = function (obj)
+ Block = obj
+ Compile({
+ type = 'local',
+ start = 0,
+ finish = 0,
+ effect = 0,
+ tag = '_ENV',
+ special= '_G',
+ [1] = ENVMode,
+ }, obj)
+ --- _ENV 是上值,不计入局部变量计数
+ LocalCount = 0
+ CompileBlock(obj, obj)
+ Block = nil
+ end,
+}
+
+function CompileBlock(obj, parent)
+ for i = 1, #obj do
+ local act = obj[i]
+ local f = vmMap[act.type]
+ if f then
+ act.parent = parent
+ f(act)
+ end
+ end
+end
+
+function Compile(obj, parent)
+ if not obj then
+ return nil
+ end
+ if Compiled[obj] then
+ return
+ end
+ Compiled[obj] = true
+ obj.parent = parent
+ local f = vmMap[obj.type]
+ if not f then
+ return
+ end
+ f(obj)
+end
+
+local function compileGoTo(obj)
+ local name = obj[1]
+ local label = guide.getLabel(obj, name)
+ if not label then
+ pushError {
+ type = 'NO_VISIBLE_LABEL',
+ start = obj.start,
+ finish = obj.finish,
+ info = {
+ label = name,
+ }
+ }
+ return
+ end
+ if not label.ref then
+ label.ref = {}
+ end
+ label.ref[#label.ref+1] = obj
+ obj.node = label
+
+ -- 如果有局部变量在 goto 与 label 之间声明,
+ -- 并在 label 之后使用,则算作语法错误
+
+ -- 如果 label 在 goto 之前声明,那么不会有中间声明的局部变量
+ if obj.start > label.start then
+ return
+ end
+
+ local block = guide.getBlock(obj)
+ local locals = block and block.locals
+ if not locals then
+ return
+ end
+
+ for i = 1, #locals do
+ local loc = locals[i]
+ -- 检查局部变量声明位置为 goto 与 label 之间
+ if loc.start < obj.start or loc.finish > label.finish then
+ goto CONTINUE
+ end
+ -- 检查局部变量的使用位置在 label 之后
+ local refs = loc.ref
+ if not refs then
+ goto CONTINUE
+ end
+ for j = 1, #refs do
+ local ref = refs[j]
+ if ref.finish > label.finish then
+ pushError {
+ type = 'JUMP_LOCAL_SCOPE',
+ start = obj.start,
+ finish = obj.finish,
+ info = {
+ loc = loc[1],
+ },
+ relative = {
+ {
+ start = label.start,
+ finish = label.finish,
+ },
+ {
+ start = loc.start,
+ finish = loc.finish,
+ }
+ },
+ }
+ return
+ end
+ end
+ ::CONTINUE::
+ end
+end
+
+local function PostCompile()
+ for i = 1, #GoToTag do
+ compileGoTo(GoToTag[i])
+ end
+end
+
+return function (self, lua, mode, version, options)
+ local state, err = self:parse(lua, mode, version)
+ if not state then
+ return nil, err
+ end
+ pushError = state.pushError
+ if version == 'Lua 5.1' or version == 'LuaJIT' then
+ ENVMode = '@fenv'
+ else
+ ENVMode = '_ENV'
+ end
+ Compiled = {}
+ GoToTag = {}
+ LocalCount = 0
+ Version = version
+ Root = state.ast
+ Root.state = state
+ Options = options
+ state.ENVMode = ENVMode
+ if type(state.ast) == 'table' then
+ Compile(state.ast)
+ end
+ PostCompile()
+ Compiled = nil
+ GoToTag = nil
+ return state
+end
diff --git a/script/parser/grammar.lua b/script/parser/grammar.lua
new file mode 100644
index 00000000..06dae246
--- /dev/null
+++ b/script/parser/grammar.lua
@@ -0,0 +1,538 @@
+local re = require 'parser.relabel'
+local m = require 'lpeglabel'
+local ast = require 'parser.ast'
+
+local scriptBuf = ''
+local compiled = {}
+local defs = ast.defs
+
+-- goto 可以作为名字,合法性之后处理
+local RESERVED = {
+ ['and'] = true,
+ ['break'] = true,
+ ['do'] = true,
+ ['else'] = true,
+ ['elseif'] = true,
+ ['end'] = true,
+ ['false'] = true,
+ ['for'] = true,
+ ['function'] = true,
+ ['if'] = true,
+ ['in'] = true,
+ ['local'] = true,
+ ['nil'] = true,
+ ['not'] = true,
+ ['or'] = true,
+ ['repeat'] = true,
+ ['return'] = true,
+ ['then'] = true,
+ ['true'] = true,
+ ['until'] = true,
+ ['while'] = true,
+}
+
+defs.nl = (m.P'\r\n' + m.S'\r\n')
+defs.s = m.S' \t'
+defs.S = - defs.s
+defs.ea = '\a'
+defs.eb = '\b'
+defs.ef = '\f'
+defs.en = '\n'
+defs.er = '\r'
+defs.et = '\t'
+defs.ev = '\v'
+defs['nil'] = m.Cp() / function () return nil end
+defs['false'] = m.Cp() / function () return false end
+defs.NotReserved = function (_, _, str)
+ if RESERVED[str] then
+ return false
+ end
+ return true
+end
+defs.Reserved = function (_, _, str)
+ if RESERVED[str] then
+ return true
+ end
+ return false
+end
+defs.None = function () end
+defs.np = m.Cp() / function (n) return n+1 end
+
+m.setmaxstack(1000)
+
+local eof = re.compile '!. / %{SYNTAX_ERROR}'
+
+local function grammar(tag)
+ return function (script)
+ scriptBuf = script .. '\r\n' .. scriptBuf
+ compiled[tag] = re.compile(scriptBuf, defs) * eof
+ end
+end
+
+local function errorpos(pos, err)
+ return {
+ type = 'UNKNOWN',
+ start = pos or 0,
+ finish = pos or 0,
+ err = err,
+ }
+end
+
+grammar 'Comment' [[
+Comment <- LongComment
+ / '--' ShortComment
+LongComment <- ('--[' {} {:eq: '='* :} {} '['
+ {(!CommentClose .)*}
+ (CommentClose / {}))
+ -> LongComment
+ / (
+ {} '/*' {}
+ (!'*/' .)*
+ {} '*/' {}
+ )
+ -> CLongComment
+CommentClose <- ']' =eq ']'
+ShortComment <- ({} {(!%nl .)*} {})
+ -> ShortComment
+]]
+
+grammar 'Sp' [[
+Sp <- (Comment / %nl / %s)*
+Sps <- (Comment / %nl / %s)+
+]]
+
+grammar 'Common' [[
+Word <- [a-zA-Z0-9_]
+Cut <- !Word
+X16 <- [a-fA-F0-9]
+Rest <- (!%nl .)*
+
+AND <- Sp {'and'} Cut
+BREAK <- Sp 'break' Cut
+FALSE <- Sp 'false' Cut
+GOTO <- Sp 'goto' Cut
+LOCAL <- Sp 'local' Cut
+NIL <- Sp 'nil' Cut
+NOT <- Sp 'not' Cut
+OR <- Sp {'or'} Cut
+RETURN <- Sp 'return' Cut
+TRUE <- Sp 'true' Cut
+
+DO <- Sp {} 'do' {} Cut
+ / Sp({} 'then' {} Cut) -> ErrDo
+IF <- Sp {} 'if' {} Cut
+ELSE <- Sp {} 'else' {} Cut
+ELSEIF <- Sp {} 'elseif' {} Cut
+END <- Sp {} 'end' {} Cut
+FOR <- Sp {} 'for' {} Cut
+FUNCTION <- Sp {} 'function' {} Cut
+IN <- Sp {} 'in' {} Cut
+REPEAT <- Sp {} 'repeat' {} Cut
+THEN <- Sp {} 'then' {} Cut
+ / Sp({} 'do' {} Cut) -> ErrThen
+UNTIL <- Sp {} 'until' {} Cut
+WHILE <- Sp {} 'while' {} Cut
+
+
+Esc <- '\' -> ''
+ EChar
+EChar <- 'a' -> ea
+ / 'b' -> eb
+ / 'f' -> ef
+ / 'n' -> en
+ / 'r' -> er
+ / 't' -> et
+ / 'v' -> ev
+ / '\'
+ / '"'
+ / "'"
+ / %nl
+ / ('z' (%nl / %s)*) -> ''
+ / ({} 'x' {X16 X16}) -> Char16
+ / ([0-9] [0-9]? [0-9]?) -> Char10
+ / ('u{' {} {Word*} '}') -> CharUtf8
+ -- 错误处理
+ / 'x' {} -> MissEscX
+ / 'u' !'{' {} -> MissTL
+ / 'u{' Word* !'}' {} -> MissTR
+ / {} -> ErrEsc
+
+BOR <- Sp {'|'}
+BXOR <- Sp {'~'} !'='
+BAND <- Sp {'&'}
+Bshift <- Sp {BshiftList}
+BshiftList <- '<<'
+ / '>>'
+Concat <- Sp {'..'}
+Adds <- Sp {AddsList}
+AddsList <- '+'
+ / '-'
+Muls <- Sp {MulsList}
+MulsList <- '*'
+ / '//'
+ / '/'
+ / '%'
+Unary <- Sp {} {UnaryList}
+UnaryList <- NOT
+ / '#'
+ / '-'
+ / '~' !'='
+POWER <- Sp {'^'}
+
+BinaryOp <-( Sp {} {'or'} Cut
+ / Sp {} {'and'} Cut
+ / Sp {} {'<=' / '>=' / '<'!'<' / '>'!'>' / '~=' / '=='}
+ / Sp {} ({} '=' {}) -> ErrEQ
+ / Sp {} ({} '!=' {}) -> ErrUEQ
+ / Sp {} {'|'}
+ / Sp {} {'~'}
+ / Sp {} {'&'}
+ / Sp {} {'<<' / '>>'}
+ / Sp {} {'..'} !'.'
+ / Sp {} {'+' / '-'}
+ / Sp {} {'*' / '//' / '/' / '%'}
+ / Sp {} {'^'}
+ )-> BinaryOp
+UnaryOp <-( Sp {} {'not' Cut / '#' / '~' !'=' / '-' !'-'}
+ )-> UnaryOp
+
+PL <- Sp '('
+PR <- Sp ')'
+BL <- Sp '[' !'[' !'='
+BR <- Sp ']'
+TL <- Sp '{'
+TR <- Sp '}'
+COMMA <- Sp ({} ',')
+ -> COMMA
+SEMICOLON <- Sp ({} ';')
+ -> SEMICOLON
+DOTS <- Sp ({} '...')
+ -> DOTS
+DOT <- Sp ({} '.' !'.')
+ -> DOT
+COLON <- Sp ({} ':' !':')
+ -> COLON
+LABEL <- Sp '::'
+ASSIGN <- Sp '=' !'='
+AssignOrEQ <- Sp ({} '==' {})
+ -> ErrAssign
+ / Sp '='
+
+DirtyBR <- BR / {} -> MissBR
+DirtyTR <- TR / {} -> MissTR
+DirtyPR <- PR / {} -> MissPR
+DirtyLabel <- LABEL / {} -> MissLabel
+NeedEnd <- END / {} -> MissEnd
+NeedDo <- DO / {} -> MissDo
+NeedAssign <- ASSIGN / {} -> MissAssign
+NeedComma <- COMMA / {} -> MissComma
+NeedIn <- IN / {} -> MissIn
+NeedUntil <- UNTIL / {} -> MissUntil
+NeedThen <- THEN / {} -> MissThen
+]]
+
+grammar 'Nil' [[
+Nil <- Sp ({} -> Nil) NIL
+]]
+
+grammar 'Boolean' [[
+Boolean <- Sp ({} -> True) TRUE
+ / Sp ({} -> False) FALSE
+]]
+
+grammar 'String' [[
+String <- Sp ({} StringDef {})
+ -> String
+StringDef <- {'"'}
+ {~(Esc / !%nl !'"' .)*~} -> 1
+ ('"' / {} -> MissQuote1)
+ / {"'"}
+ {~(Esc / !%nl !"'" .)*~} -> 1
+ ("'" / {} -> MissQuote2)
+ / ('[' {} {:eq: '='* :} {} '[' %nl?
+ {(!StringClose .)*} -> 1
+ (StringClose / {}))
+ -> LongString
+StringClose <- ']' =eq ']'
+]]
+
+grammar 'Number' [[
+Number <- Sp ({} {NumberDef} {}) -> Number
+ NumberSuffix?
+ ErrNumber?
+NumberDef <- Number16 / Number10
+NumberSuffix<- ({} {[uU]? [lL] [lL]}) -> FFINumber
+ / ({} {[iI]}) -> ImaginaryNumber
+ErrNumber <- ({} {([0-9a-zA-Z] / '.')+}) -> UnknownSymbol
+
+Number10 <- Float10 Float10Exp?
+ / Integer10 Float10? Float10Exp?
+Integer10 <- [0-9]+ ('.' [0-9]*)?
+Float10 <- '.' [0-9]+
+Float10Exp <- [eE] [+-]? [0-9]+
+ / ({} [eE] [+-]? {}) -> MissExponent
+
+Number16 <- '0' [xX] Float16 Float16Exp?
+ / '0' [xX] Integer16 Float16? Float16Exp?
+Integer16 <- X16+ ('.' X16*)?
+ / ({} {Word*}) -> MustX16
+Float16 <- '.' X16+
+ / '.' ({} {Word*}) -> MustX16
+Float16Exp <- [pP] [+-]? [0-9]+
+ / ({} [pP] [+-]? {}) -> MissExponent
+]]
+
+grammar 'Name' [[
+Name <- Sp ({} NameBody {})
+ -> Name
+NameBody <- {[a-zA-Z_] [a-zA-Z0-9_]*}
+FreeName <- Sp ({} {NameBody=>NotReserved} {})
+ -> Name
+KeyWord <- Sp NameBody=>Reserved
+MustName <- Name / DirtyName
+DirtyName <- {} -> DirtyName
+]]
+
+grammar 'Exp' [[
+Exp <- (UnUnit BinUnit*)
+ -> Binary
+BinUnit <- (BinaryOp UnUnit?)
+ -> SubBinary
+UnUnit <- ExpUnit
+ / (UnaryOp+ (ExpUnit / MissExp))
+ -> Unary
+ExpUnit <- Nil
+ / Boolean
+ / String
+ / Number
+ / Dots
+ / Table
+ / Function
+ / Simple
+
+Simple <- {| Prefix (Sp Suffix)* |}
+ -> Simple
+Prefix <- Sp ({} PL DirtyExp DirtyPR {})
+ -> Paren
+ / Single
+Single <- FreeName
+ -> Single
+Suffix <- SuffixWithoutCall
+ / ({} PL SuffixCall DirtyPR {})
+ -> Call
+SuffixCall <- Sp ({} {| (COMMA / Exp)+ |} {})
+ -> PackExpList
+ / %nil
+SuffixWithoutCall
+ <- (DOT (Name / MissField))
+ -> GetField
+ / ({} BL DirtyExp DirtyBR {})
+ -> GetIndex
+ / (COLON (Name / MissMethod) NeedCall)
+ -> GetMethod
+ / ({} {| Table |} {})
+ -> Call
+ / ({} {| String |} {})
+ -> Call
+NeedCall <- (!(Sp CallStart) {} -> MissPL)?
+MissField <- {} -> MissField
+MissMethod <- {} -> MissMethod
+CallStart <- PL
+ / TL
+ / '"'
+ / "'"
+ / '[' '='* '['
+
+DirtyExp <- Exp
+ / {} -> DirtyExp
+MaybeExp <- Exp / MissExp
+MissExp <- {} -> MissExp
+ExpList <- Sp {| MaybeExp (Sp ',' MaybeExp)* |}
+
+Dots <- DOTS
+ -> VarArgs
+
+Table <- Sp ({} TL {| TableField* |} DirtyTR {})
+ -> Table
+TableField <- COMMA
+ / SEMICOLON
+ / NewIndex
+ / NewField
+ / Exp
+Index <- BL DirtyExp DirtyBR
+NewIndex <- Sp ({} Index NeedAssign DirtyExp {})
+ -> NewIndex
+NewField <- Sp ({} MustName ASSIGN DirtyExp {})
+ -> NewField
+
+Function <- FunctionBody
+ -> Function
+FuncArgs <- Sp ({} PL {| FuncArg+ |} DirtyPR {})
+ -> FuncArgs
+ / PL DirtyPR %nil
+FuncArgsMiss<- {} -> MissPL DirtyPR %nil
+FuncArg <- DOTS
+ / Name
+ / COMMA
+FunctionBody<- FUNCTION FuncArgs
+ {| (!END Action)* |}
+ NeedEnd
+ / FUNCTION FuncArgsMiss
+ {| %nil |}
+ NeedEnd
+
+-- 纯占位,修改了 `relabel.lua` 使重复定义不抛错
+Action <- !END .
+]]
+
+grammar 'Action' [[
+Action <- Sp (CrtAction / UnkAction)
+CrtAction <- Semicolon
+ / Do
+ / Break
+ / Return
+ / Label
+ / GoTo
+ / If
+ / For
+ / While
+ / Repeat
+ / NamedFunction
+ / LocalFunction
+ / Local
+ / Set
+ / Call
+ / ExpInAction
+UnkAction <- ({} {Word+})
+ -> UnknownAction
+ / ({} '//' {} (LongComment / ShortComment))
+ -> CCommentPrefix
+ / ({} {. (!Sps !CrtAction .)*})
+ -> UnknownAction
+ExpInAction <- Sp ({} Exp {})
+ -> ExpInAction
+
+Semicolon <- Sp ';'
+SimpleList <- {| Simple (Sp ',' Simple)* |}
+
+Do <- Sp ({}
+ 'do' Cut
+ {| (!END Action)* |}
+ NeedEnd)
+ -> Do
+
+Break <- Sp ({} BREAK {})
+ -> Break
+
+Return <- Sp ({} RETURN ReturnExpList {})
+ -> Return
+ReturnExpList
+ <- Sp {| Exp (Sp ',' MaybeExp)* |}
+ / Sp {| !Exp !',' |}
+ / ExpList
+
+Label <- Sp ({} LABEL MustName DirtyLabel {})
+ -> Label
+
+GoTo <- Sp ({} GOTO MustName {})
+ -> GoTo
+
+If <- Sp ({} {| IfHead IfBody* |} NeedEnd)
+ -> If
+
+IfHead <- Sp (IfPart {}) -> IfBlock
+ / Sp (ElseIfPart {}) -> ElseIfBlock
+ / Sp (ElsePart {}) -> ElseBlock
+IfBody <- Sp (ElseIfPart {}) -> ElseIfBlock
+ / Sp (ElsePart {}) -> ElseBlock
+IfPart <- IF DirtyExp NeedThen
+ {| (!ELSEIF !ELSE !END Action)* |}
+ElseIfPart <- ELSEIF DirtyExp NeedThen
+ {| (!ELSEIF !ELSE !END Action)* |}
+ElsePart <- ELSE
+ {| (!ELSEIF !ELSE !END Action)* |}
+
+For <- Loop / In
+
+Loop <- LoopBody
+ -> Loop
+LoopBody <- FOR LoopArgs NeedDo
+ {} {| (!END Action)* |}
+ NeedEnd
+LoopArgs <- MustName AssignOrEQ
+ ({} {| (COMMA / !DO !END Exp)* |} {})
+ -> PackLoopArgs
+
+In <- InBody
+ -> In
+InBody <- FOR InNameList NeedIn InExpList NeedDo
+ {} {| (!END Action)* |}
+ NeedEnd
+InNameList <- ({} {| (COMMA / !IN !DO !END Name)* |} {})
+ -> PackInNameList
+InExpList <- ({} {| (COMMA / !DO !DO !END Exp)* |} {})
+ -> PackInExpList
+
+While <- WhileBody
+ -> While
+WhileBody <- WHILE DirtyExp NeedDo
+ {| (!END Action)* |}
+ NeedEnd
+
+Repeat <- (RepeatBody {})
+ -> Repeat
+RepeatBody <- REPEAT
+ {| (!UNTIL Action)* |}
+ NeedUntil DirtyExp
+
+LocalAttr <- {| (Sp '<' Sp MustName Sp LocalAttrEnd)+ |}
+ -> LocalAttr
+LocalAttrEnd<- '>' / {} -> MissGT
+Local <- Sp ({} LOCAL LocalNameList ((AssignOrEQ ExpList) / %nil) {})
+ -> Local
+Set <- Sp ({} SimpleList AssignOrEQ ExpList {})
+ -> Set
+LocalNameList
+ <- {| LocalName (Sp ',' LocalName)* |}
+LocalName <- (MustName LocalAttr?)
+ -> LocalName
+
+Call <- Simple
+ -> SimpleCall
+
+LocalFunction
+ <- Sp ({} LOCAL FunctionNamedBody)
+ -> LocalFunction
+
+NamedFunction
+ <- FunctionNamedBody
+ -> NamedFunction
+FunctionNamedBody
+ <- FUNCTION FuncName FuncArgs
+ {| (!END Action)* |}
+ NeedEnd
+ / FUNCTION FuncName FuncArgsMiss
+ {| %nil |}
+ NeedEnd
+FuncName <- {| Single (Sp SuffixWithoutCall)* |}
+ -> Simple
+ / {} -> MissName %nil
+]]
+
+grammar 'Lua' [[
+Lua <- Head?
+ ({} {| Action* |} {}) -> Lua
+ Sp
+Head <- '#' (!%nl .)*
+]]
+
+return function (self, lua, mode)
+ local gram = compiled[mode] or compiled['Lua']
+ local r, _, pos = gram:match(lua)
+ if not r then
+ local err = errorpos(pos)
+ return nil, err
+ end
+
+ return r
+end
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
new file mode 100644
index 00000000..6ef239f1
--- /dev/null
+++ b/script/parser/guide.lua
@@ -0,0 +1,3884 @@
+local util = require 'utility'
+local error = error
+local type = type
+local next = next
+local tostring = tostring
+local print = print
+local ipairs = ipairs
+local tableInsert = table.insert
+local tableUnpack = table.unpack
+local tableRemove = table.remove
+local tableMove = table.move
+local tableSort = table.sort
+local tableConcat = table.concat
+local mathType = math.type
+local pairs = pairs
+local setmetatable = setmetatable
+local assert = assert
+local select = select
+local osClock = os.clock
+local DEVELOP = _G.DEVELOP
+local log = log
+local _G = _G
+
+local function logWarn(...)
+ log.warn(...)
+end
+
+_ENV = nil
+
+local m = {}
+
+local blockTypes = {
+ ['while'] = true,
+ ['in'] = true,
+ ['loop'] = true,
+ ['repeat'] = true,
+ ['do'] = true,
+ ['function'] = true,
+ ['ifblock'] = true,
+ ['elseblock'] = true,
+ ['elseifblock'] = true,
+ ['main'] = true,
+}
+
+local breakBlockTypes = {
+ ['while'] = true,
+ ['in'] = true,
+ ['loop'] = true,
+ ['repeat'] = true,
+}
+
+m.childMap = {
+ ['main'] = {'#', 'docs'},
+ ['repeat'] = {'#', 'filter'},
+ ['while'] = {'filter', '#'},
+ ['in'] = {'keys', '#'},
+ ['loop'] = {'loc', 'max', 'step', '#'},
+ ['if'] = {'#'},
+ ['ifblock'] = {'filter', '#'},
+ ['elseifblock'] = {'filter', '#'},
+ ['elseblock'] = {'#'},
+ ['setfield'] = {'node', 'field', 'value'},
+ ['setglobal'] = {'value'},
+ ['local'] = {'attrs', 'value'},
+ ['setlocal'] = {'value'},
+ ['return'] = {'#'},
+ ['do'] = {'#'},
+ ['select'] = {'vararg'},
+ ['table'] = {'#'},
+ ['tableindex'] = {'index', 'value'},
+ ['tablefield'] = {'field', 'value'},
+ ['function'] = {'args', '#'},
+ ['funcargs'] = {'#'},
+ ['setmethod'] = {'node', 'method', 'value'},
+ ['getmethod'] = {'node', 'method'},
+ ['setindex'] = {'node', 'index', 'value'},
+ ['getindex'] = {'node', 'index'},
+ ['paren'] = {'exp'},
+ ['call'] = {'node', 'args'},
+ ['callargs'] = {'#'},
+ ['getfield'] = {'node', 'field'},
+ ['list'] = {'#'},
+ ['binary'] = {1, 2},
+ ['unary'] = {1},
+
+ ['doc'] = {'#'},
+ ['doc.class'] = {'class', 'extends'},
+ ['doc.type'] = {'#types', '#enums', 'name'},
+ ['doc.alias'] = {'alias', 'extends'},
+ ['doc.param'] = {'param', 'extends'},
+ ['doc.return'] = {'#returns'},
+ ['doc.field'] = {'field', 'extends'},
+ ['doc.generic'] = {'#generics'},
+ ['doc.generic.object'] = {'generic', 'extends'},
+ ['doc.vararg'] = {'vararg'},
+ ['doc.type.table'] = {'key', 'value'},
+ ['doc.type.function'] = {'#args', '#returns'},
+ ['doc.overload'] = {'overload'},
+}
+
+m.actionMap = {
+ ['main'] = {'#'},
+ ['repeat'] = {'#'},
+ ['while'] = {'#'},
+ ['in'] = {'#'},
+ ['loop'] = {'#'},
+ ['if'] = {'#'},
+ ['ifblock'] = {'#'},
+ ['elseifblock'] = {'#'},
+ ['elseblock'] = {'#'},
+ ['do'] = {'#'},
+ ['function'] = {'#'},
+ ['funcargs'] = {'#'},
+}
+
+local TypeSort = {
+ ['boolean'] = 1,
+ ['string'] = 2,
+ ['integer'] = 3,
+ ['number'] = 4,
+ ['table'] = 5,
+ ['function'] = 6,
+ ['nil'] = 999,
+}
+
+local NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end })
+
+--- 是否是字面量
+function m.isLiteral(obj)
+ local tp = obj.type
+ return tp == 'nil'
+ or tp == 'boolean'
+ or tp == 'string'
+ or tp == 'number'
+ or tp == 'table'
+end
+
+--- 获取字面量
+function m.getLiteral(obj)
+ local tp = obj.type
+ if tp == 'boolean' then
+ return obj[1]
+ elseif tp == 'string' then
+ return obj[1]
+ elseif tp == 'number' then
+ return obj[1]
+ end
+ return nil
+end
+
+--- 寻找父函数
+function m.getParentFunction(obj)
+ for _ = 1, 1000 do
+ obj = obj.parent
+ if not obj then
+ break
+ end
+ local tp = obj.type
+ if tp == 'function' or tp == 'main' then
+ return obj
+ end
+ end
+ return nil
+end
+
+--- 寻找所在区块
+function m.getBlock(obj)
+ for _ = 1, 1000 do
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if blockTypes[tp] then
+ return obj
+ end
+ obj = obj.parent
+ end
+ error('guide.getBlock overstack')
+end
+
+--- 寻找所在父区块
+function m.getParentBlock(obj)
+ for _ = 1, 1000 do
+ obj = obj.parent
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if blockTypes[tp] then
+ return obj
+ end
+ end
+ error('guide.getParentBlock overstack')
+end
+
+--- 寻找所在可break的父区块
+function m.getBreakBlock(obj)
+ for _ = 1, 1000 do
+ obj = obj.parent
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if breakBlockTypes[tp] then
+ return obj
+ end
+ if tp == 'function' then
+ return nil
+ end
+ end
+ error('guide.getBreakBlock overstack')
+end
+
+--- 寻找doc的主体
+function m.getDocState(obj)
+ for _ = 1, 1000 do
+ local parent = obj.parent
+ if not parent then
+ return obj
+ end
+ if parent.type == 'doc' then
+ return obj
+ end
+ obj = parent
+ end
+ error('guide.getDocState overstack')
+end
+
+--- 寻找所在父类型
+function m.getParentType(obj, want)
+ for _ = 1, 1000 do
+ obj = obj.parent
+ if not obj then
+ return nil
+ end
+ if want == obj.type then
+ return obj
+ end
+ end
+ error('guide.getParentType overstack')
+end
+
+--- 寻找根区块
+function m.getRoot(obj)
+ for _ = 1, 1000 do
+ if obj.type == 'main' then
+ return obj
+ end
+ local parent = obj.parent
+ if not parent then
+ return nil
+ end
+ obj = parent
+ end
+ error('guide.getRoot overstack')
+end
+
+function m.getUri(obj)
+ if obj.uri then
+ return obj.uri
+ end
+ local root = m.getRoot(obj)
+ if root then
+ return root.uri
+ end
+ return ''
+end
+
+function m.getENV(source, start)
+ if not start then
+ start = 1
+ end
+ return m.getLocal(source, '_ENV', start)
+ or m.getLocal(source, '@fenv', start)
+end
+
+--- 寻找函数的不定参数,返回不定参在第几个参数上,以及该参数对象。
+--- 如果函数是主函数,则返回`0, nil`。
+---@return table
+---@return integer
+function m.getFunctionVarArgs(func)
+ if func.type == 'main' then
+ return 0, nil
+ end
+ if func.type ~= 'function' then
+ return nil, nil
+ end
+ local args = func.args
+ if not args then
+ return nil, nil
+ end
+ for i = 1, #args do
+ local arg = args[i]
+ if arg.type == '...' then
+ return i, arg
+ end
+ end
+ return nil, nil
+end
+
+--- 获取指定区块中可见的局部变量
+---@param block table
+---@param name string {comment = '变量名'}
+---@param pos integer {comment = '可见位置'}
+function m.getLocal(block, name, pos)
+ block = m.getBlock(block)
+ for _ = 1, 1000 do
+ if not block then
+ return nil
+ end
+ local locals = block.locals
+ local res
+ if not locals then
+ goto CONTINUE
+ end
+ for i = 1, #locals do
+ local loc = locals[i]
+ if loc.effect > pos then
+ break
+ end
+ if loc[1] == name then
+ if not res or res.effect < loc.effect then
+ res = loc
+ end
+ end
+ end
+ if res then
+ return res, res
+ end
+ ::CONTINUE::
+ block = m.getParentBlock(block)
+ end
+ error('guide.getLocal overstack')
+end
+
+--- 获取指定区块中所有的可见局部变量名称
+function m.getVisibleLocals(block, pos)
+ local result = {}
+ m.eachSourceContain(m.getRoot(block), pos, function (source)
+ local locals = source.locals
+ if locals then
+ for i = 1, #locals do
+ local loc = locals[i]
+ local name = loc[1]
+ if loc.effect <= pos then
+ result[name] = loc
+ end
+ end
+ end
+ end)
+ return result
+end
+
+--- 获取指定区块中可见的标签
+---@param block table
+---@param name string {comment = '标签名'}
+function m.getLabel(block, name)
+ block = m.getBlock(block)
+ for _ = 1, 1000 do
+ if not block then
+ return nil
+ end
+ local labels = block.labels
+ if labels then
+ local label = labels[name]
+ if label then
+ return label
+ end
+ end
+ if block.type == 'function' then
+ return nil
+ end
+ block = m.getParentBlock(block)
+ end
+ error('guide.getLocal overstack')
+end
+
+function m.getStartFinish(source)
+ local start = source.start
+ local finish = source.finish
+ if not start then
+ local first = source[1]
+ if not first then
+ return nil, nil
+ end
+ local last = source[#source]
+ start = first.start
+ finish = last.finish
+ end
+ return start, finish
+end
+
+function m.getRange(source)
+ local start = source.vstart or source.start
+ local finish = source.range or source.finish
+ if not start then
+ local first = source[1]
+ if not first then
+ return nil, nil
+ end
+ local last = source[#source]
+ start = first.vstart or first.start
+ finish = last.range or last.finish
+ end
+ return start, finish
+end
+
+--- 判断source是否包含offset
+function m.isContain(source, offset)
+ local start, finish = m.getStartFinish(source)
+ if not start then
+ return false
+ end
+ return start <= offset and finish >= offset - 1
+end
+
+--- 判断offset在source的影响范围内
+---
+--- 主要针对赋值等语句时,key包含value
+function m.isInRange(source, offset)
+ local start, finish = m.getRange(source)
+ if not start then
+ return false
+ end
+ return start <= offset and finish >= offset - 1
+end
+
+function m.isBetween(source, tStart, tFinish)
+ local start, finish = m.getStartFinish(source)
+ if not start then
+ return false
+ end
+ return start <= tFinish and finish >= tStart - 1
+end
+
+function m.isBetweenRange(source, tStart, tFinish)
+ local start, finish = m.getRange(source)
+ if not start then
+ return false
+ end
+ return start <= tFinish and finish >= tStart - 1
+end
+
+--- 添加child
+function m.addChilds(list, obj, map)
+ local keys = map[obj.type]
+ if keys then
+ for i = 1, #keys do
+ local key = keys[i]
+ if key == '#' then
+ for i = 1, #obj do
+ list[#list+1] = obj[i]
+ end
+ elseif obj[key] then
+ list[#list+1] = obj[key]
+ elseif type(key) == 'string'
+ and key:sub(1, 1) == '#' then
+ key = key:sub(2)
+ for i = 1, #obj[key] do
+ list[#list+1] = obj[key][i]
+ end
+ end
+ end
+ end
+end
+
+--- 遍历所有包含offset的source
+function m.eachSourceContain(ast, offset, callback)
+ local list = { ast }
+ local mark = {}
+ while true do
+ local len = #list
+ if len == 0 then
+ return
+ end
+ local obj = list[len]
+ list[len] = nil
+ if not mark[obj] then
+ mark[obj] = true
+ if m.isInRange(obj, offset) then
+ if m.isContain(obj, offset) then
+ local res = callback(obj)
+ if res ~= nil then
+ return res
+ end
+ end
+ m.addChilds(list, obj, m.childMap)
+ end
+ end
+ end
+end
+
+--- 遍历所有在某个范围内的source
+function m.eachSourceBetween(ast, start, finish, callback)
+ local list = { ast }
+ local mark = {}
+ while true do
+ local len = #list
+ if len == 0 then
+ return
+ end
+ local obj = list[len]
+ list[len] = nil
+ if not mark[obj] then
+ mark[obj] = true
+ if m.isBetweenRange(obj, start, finish) then
+ if m.isBetween(obj, start, finish) then
+ local res = callback(obj)
+ if res ~= nil then
+ return res
+ end
+ end
+ m.addChilds(list, obj, m.childMap)
+ end
+ end
+ end
+end
+
+--- 遍历所有指定类型的source
+function m.eachSourceType(ast, type, callback)
+ local cache = ast.typeCache
+ if not cache then
+ cache = {}
+ ast.typeCache = cache
+ m.eachSource(ast, function (source)
+ local tp = source.type
+ if not tp then
+ return
+ end
+ local myCache = cache[tp]
+ if not myCache then
+ myCache = {}
+ cache[tp] = myCache
+ end
+ myCache[#myCache+1] = source
+ end)
+ end
+ local myCache = cache[type]
+ if not myCache then
+ return
+ end
+ for i = 1, #myCache do
+ callback(myCache[i])
+ end
+end
+
+--- 遍历所有的source
+function m.eachSource(ast, callback)
+ local list = { ast }
+ local mark = {}
+ local index = 1
+ while true do
+ local obj = list[index]
+ if not obj then
+ return
+ end
+ list[index] = false
+ index = index + 1
+ if not mark[obj] then
+ mark[obj] = true
+ callback(obj)
+ m.addChilds(list, obj, m.childMap)
+ end
+ end
+end
+
+--- 获取指定的 special
+function m.eachSpecialOf(ast, name, callback)
+ local root = m.getRoot(ast)
+ if not root.specials then
+ return
+ end
+ local specials = root.specials[name]
+ if not specials then
+ return
+ end
+ for i = 1, #specials do
+ callback(specials[i])
+ end
+end
+
+--- 获取偏移对应的坐标
+---@param lines table
+---@return integer {name = 'row'}
+---@return integer {name = 'col'}
+function m.positionOf(lines, offset)
+ if offset < 1 then
+ return 0, 0
+ end
+ local lastLine = lines[#lines]
+ if offset > lastLine.finish then
+ return #lines, lastLine.finish - lastLine.start + 1
+ end
+ local min = 1
+ local max = #lines
+ for _ = 1, 100 do
+ if max <= min then
+ local line = lines[min]
+ return min, offset - line.start + 1
+ end
+ local row = (max - min) // 2 + min
+ local line = lines[row]
+ if offset < line.start then
+ max = row - 1
+ elseif offset > line.finish then
+ min = row + 1
+ else
+ return row, offset - line.start + 1
+ end
+ end
+ error('Stack overflow!')
+end
+
+--- 获取坐标对应的偏移
+---@param lines table
+---@param row integer
+---@param col integer
+---@return integer {name = 'offset'}
+function m.offsetOf(lines, row, col)
+ if row < 1 then
+ return 0
+ end
+ if row > #lines then
+ local lastLine = lines[#lines]
+ return lastLine.finish
+ end
+ local line = lines[row]
+ local len = line.finish - line.start + 1
+ if col < 0 then
+ return line.start
+ elseif col > len then
+ return line.finish
+ else
+ return line.start + col - 1
+ end
+end
+
+function m.lineContent(lines, text, row, ignoreNL)
+ local line = lines[row]
+ if not line then
+ return ''
+ end
+ if ignoreNL then
+ return text:sub(line.start, line.range)
+ else
+ return text:sub(line.start, line.finish)
+ end
+end
+
+function m.lineRange(lines, row, ignoreNL)
+ local line = lines[row]
+ if not line then
+ return 0, 0
+ end
+ if ignoreNL then
+ return line.start, line.range
+ else
+ return line.start, line.finish
+ end
+end
+
+function m.getNameOfLiteral(obj)
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if tp == 'string' then
+ return obj[1]
+ end
+ return nil
+end
+
+function m.getName(obj)
+ local tp = obj.type
+ if tp == 'getglobal'
+ or tp == 'setglobal' then
+ return obj[1]
+ elseif tp == 'local'
+ or tp == 'getlocal'
+ or tp == 'setlocal' then
+ return obj[1]
+ elseif tp == 'getfield'
+ or tp == 'setfield'
+ or tp == 'tablefield' then
+ return obj.field and obj.field[1]
+ elseif tp == 'getmethod'
+ or tp == 'setmethod' then
+ return obj.method and obj.method[1]
+ elseif tp == 'getindex'
+ or tp == 'setindex'
+ or tp == 'tableindex' then
+ return m.getNameOfLiteral(obj.index)
+ elseif tp == 'field'
+ or tp == 'method' then
+ return obj[1]
+ elseif tp == 'doc.class' then
+ return obj.class[1]
+ elseif tp == 'doc.alias' then
+ return obj.alias[1]
+ elseif tp == 'doc.field' then
+ return obj.field[1]
+ end
+ return m.getNameOfLiteral(obj)
+end
+
+function m.getKeyNameOfLiteral(obj)
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if tp == 'field'
+ or tp == 'method' then
+ return 's|' .. obj[1]
+ elseif tp == 'string' then
+ local s = obj[1]
+ if s then
+ return 's|' .. s
+ else
+ return 's'
+ end
+ elseif tp == 'number' then
+ local n = obj[1]
+ if n then
+ return ('n|%s'):format(util.viewLiteral(obj[1]))
+ else
+ return 'n'
+ end
+ elseif tp == 'boolean' then
+ local b = obj[1]
+ if b then
+ return 'b|' .. tostring(b)
+ else
+ return 'b'
+ end
+ end
+ return nil
+end
+
+function m.getKeyName(obj)
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if tp == 'getglobal'
+ or tp == 'setglobal' then
+ return 's|' .. obj[1]
+ elseif tp == 'local'
+ or tp == 'getlocal'
+ or tp == 'setlocal' then
+ return 'l|' .. obj[1]
+ elseif tp == 'getfield'
+ or tp == 'setfield'
+ or tp == 'tablefield' then
+ if obj.field then
+ return 's|' .. obj.field[1]
+ end
+ elseif tp == 'getmethod'
+ or tp == 'setmethod' then
+ if obj.method then
+ return 's|' .. obj.method[1]
+ end
+ elseif tp == 'getindex'
+ or tp == 'setindex'
+ or tp == 'tableindex' then
+ return m.getKeyNameOfLiteral(obj.index)
+ elseif tp == 'field'
+ or tp == 'method' then
+ return 's|' .. obj[1]
+ elseif tp == 'doc.class' then
+ return 's|' .. obj.class[1]
+ elseif tp == 'doc.alias' then
+ return 's|' .. obj.alias[1]
+ elseif tp == 'doc.field' then
+ return 's|' .. obj.field[1]
+ end
+ return m.getKeyNameOfLiteral(obj)
+end
+
+function m.getSimpleName(obj)
+ if obj.type == 'call' then
+ local key = obj.args and obj.args[2]
+ return m.getKeyName(key)
+ elseif obj.type == 'table' then
+ return ('t|%p'):format(obj)
+ elseif obj.type == 'select' then
+ return ('v|%p'):format(obj)
+ elseif obj.type == 'string' then
+ return ('z|%p'):format(obj)
+ elseif obj.type == 'doc.class.name'
+ or obj.type == 'doc.type.name' then
+ return ('c|%s'):format(obj[1])
+ elseif obj.type == 'doc.class' then
+ return ('c|%s'):format(obj.class[1])
+ end
+ return m.getKeyName(obj)
+end
+
+--- 测试 a 到 b 的路径(不经过函数,不考虑 goto),
+--- 每个路径是一个 block 。
+---
+--- 如果 a 在 b 的前面,返回 `"before"` 加上 2个`list<block>`
+---
+--- 如果 a 在 b 的后面,返回 `"after"` 加上 2个`list<block>`
+---
+--- 否则返回 `false`
+---
+--- 返回的2个 `list` 分别为基准block到达 a 与 b 的路径。
+---@param a table
+---@param b table
+---@return string|boolean mode
+---@return table pathA?
+---@return table pathB?
+function m.getPath(a, b, sameFunction)
+ --- 首先测试双方在同一个函数内
+ if sameFunction and m.getParentFunction(a) ~= m.getParentFunction(b) then
+ return false
+ end
+ local mode
+ local objA
+ local objB
+ if a.finish < b.start then
+ mode = 'before'
+ objA = a
+ objB = b
+ elseif a.start > b.finish then
+ mode = 'after'
+ objA = b
+ objB = a
+ else
+ return 'equal', {}, {}
+ end
+ local pathA = {}
+ local pathB = {}
+ for _ = 1, 1000 do
+ objA = m.getParentBlock(objA)
+ pathA[#pathA+1] = objA
+ if (not sameFunction and objA.type == 'function') or objA.type == 'main' then
+ break
+ end
+ end
+ for _ = 1, 1000 do
+ objB = m.getParentBlock(objB)
+ pathB[#pathB+1] = objB
+ if (not sameFunction and objA.type == 'function') or objB.type == 'main' then
+ break
+ end
+ end
+ -- pathA: {1, 2, 3, 4, 5}
+ -- pathB: {5, 6, 2, 3}
+ local top = #pathB
+ local start
+ for i = #pathA, 1, -1 do
+ local currentBlock = pathA[i]
+ if currentBlock == pathB[top] then
+ start = i
+ break
+ end
+ end
+ if not start then
+ return nil
+ end
+ -- pathA: { 1, 2, 3}
+ -- pathB: {5, 6, 2, 3}
+ local extra = 0
+ local align = top - start
+ for i = start, 1, -1 do
+ local currentA = pathA[i]
+ local currentB = pathB[i+align]
+ if currentA ~= currentB then
+ extra = i
+ break
+ end
+ end
+ -- pathA: {1}
+ local resultA = {}
+ for i = extra, 1, -1 do
+ resultA[#resultA+1] = pathA[i]
+ end
+ -- pathB: {5, 6}
+ local resultB = {}
+ for i = extra + align, 1, -1 do
+ resultB[#resultB+1] = pathB[i]
+ end
+ return mode, resultA, resultB
+end
+
+-- 根据语法,单步搜索定义
+local function stepRefOfLocal(loc, mode)
+ local results = {}
+ if loc.start ~= 0 then
+ results[#results+1] = loc
+ end
+ local refs = loc.ref
+ if not refs then
+ return results
+ end
+ for i = 1, #refs do
+ local ref = refs[i]
+ if ref.start == 0 then
+ goto CONTINUE
+ end
+ if mode == 'def' then
+ if ref.type == 'local'
+ or ref.type == 'setlocal' then
+ results[#results+1] = ref
+ end
+ else
+ if ref.type == 'local'
+ or ref.type == 'setlocal'
+ or ref.type == 'getlocal' then
+ results[#results+1] = ref
+ end
+ end
+ ::CONTINUE::
+ end
+ return results
+end
+
+local function stepRefOfLabel(label, mode)
+ local results = { label }
+ if not label or mode == 'def' then
+ return results
+ end
+ local refs = label.ref
+ for i = 1, #refs do
+ local ref = refs[i]
+ results[#results+1] = ref
+ end
+ return results
+end
+
+local function stepRefOfDocType(status, obj, mode)
+ local results = {}
+ if obj.type == 'doc.class.name'
+ or obj.type == 'doc.type.name'
+ or obj.type == 'doc.alias.name'
+ or obj.type == 'doc.extends.name' then
+ local name = obj[1]
+ if not name or not status.interface.docType then
+ return results
+ end
+ local docs = status.interface.docType(name)
+ for i = 1, #docs do
+ local doc = docs[i]
+ if mode == 'def' then
+ if doc.type == 'doc.class.name'
+ or doc.type == 'doc.alias.name' then
+ results[#results+1] = doc
+ end
+ else
+ results[#results+1] = doc
+ end
+ end
+ else
+ results[#results+1] = obj
+ end
+ return results
+end
+
+function m.getStepRef(status, obj, mode)
+ if obj.type == 'getlocal'
+ or obj.type == 'setlocal' then
+ return stepRefOfLocal(obj.node, mode)
+ end
+ if obj.type == 'local' then
+ return stepRefOfLocal(obj, mode)
+ end
+ if obj.type == 'label' then
+ return stepRefOfLabel(obj, mode)
+ end
+ if obj.type == 'goto' then
+ return stepRefOfLabel(obj.node, mode)
+ end
+ if obj.type == 'doc.class.name'
+ or obj.type == 'doc.type.name'
+ or obj.type == 'doc.extends.name'
+ or obj.type == 'doc.alias.name' then
+ return stepRefOfDocType(status, obj, mode)
+ end
+ return nil
+end
+
+-- 根据语法,单步搜索field
+local function stepFieldOfLocal(loc)
+ local results = {}
+ local refs = loc.ref
+ for i = 1, #refs do
+ local ref = refs[i]
+ if ref.type == 'setglobal'
+ or ref.type == 'getglobal' then
+ results[#results+1] = ref
+ elseif ref.type == 'getlocal' then
+ local nxt = ref.next
+ if nxt then
+ if nxt.type == 'setfield'
+ or nxt.type == 'getfield'
+ or nxt.type == 'setmethod'
+ or nxt.type == 'getmethod'
+ or nxt.type == 'setindex'
+ or nxt.type == 'getindex' then
+ results[#results+1] = nxt
+ end
+ end
+ end
+ end
+ return results
+end
+local function stepFieldOfTable(tbl)
+ local result = {}
+ for i = 1, #tbl do
+ result[i] = tbl[i]
+ end
+ return result
+end
+function m.getStepField(obj)
+ if obj.type == 'getlocal'
+ or obj.type == 'setlocal' then
+ return stepFieldOfLocal(obj.node)
+ end
+ if obj.type == 'local' then
+ return stepFieldOfLocal(obj)
+ end
+ if obj.type == 'table' then
+ return stepFieldOfTable(obj)
+ end
+end
+
+local function convertSimpleList(list)
+ local simple = {}
+ for i = #list, 1, -1 do
+ local c = list[i]
+ if c.type == 'getglobal'
+ or c.type == 'setglobal' then
+ if c.special == '_G' then
+ simple.mode = 'global'
+ goto CONTINUE
+ end
+ local loc = c.node
+ if loc.special == '_G' then
+ simple.mode = 'global'
+ if not simple.node then
+ simple.node = c
+ end
+ else
+ simple.mode = 'local'
+ simple[#simple+1] = m.getSimpleName(loc)
+ if not simple.node then
+ simple.node = loc
+ end
+ end
+ elseif c.type == 'getlocal'
+ or c.type == 'setlocal' then
+ if c.special == '_G' then
+ simple.mode = 'global'
+ goto CONTINUE
+ end
+ simple.mode = 'local'
+ if not simple.node then
+ simple.node = c.node
+ end
+ elseif c.type == 'local' then
+ simple.mode = 'local'
+ if not simple.node then
+ simple.node = c
+ end
+ else
+ if not simple.node then
+ simple.node = c
+ end
+ end
+ simple[#simple+1] = m.getSimpleName(c)
+ ::CONTINUE::
+ end
+ if simple.mode == 'global' and #simple == 0 then
+ simple[1] = 's|_G'
+ simple.node = list[#list]
+ end
+ return simple
+end
+
+-- 搜索 `a.b.c` 的等价表达式
+local function buildSimpleList(obj, max)
+ local list = {}
+ local cur = obj
+ local limit = max and (max + 1) or 11
+ for i = 1, max or limit do
+ if i == limit then
+ return nil
+ end
+ while cur.type == 'paren' do
+ cur = cur.exp
+ if not cur then
+ return nil
+ end
+ end
+ if cur.type == 'setfield'
+ or cur.type == 'getfield'
+ or cur.type == 'setmethod'
+ or cur.type == 'getmethod'
+ or cur.type == 'setindex'
+ or cur.type == 'getindex' then
+ list[i] = cur
+ cur = cur.node
+ elseif cur.type == 'tablefield'
+ or cur.type == 'tableindex' then
+ list[i] = cur
+ cur = cur.parent.parent
+ if cur.type == 'return' then
+ list[i+1] = list[i].parent
+ break
+ end
+ elseif cur.type == 'getlocal'
+ or cur.type == 'setlocal'
+ or cur.type == 'local' then
+ list[i] = cur
+ break
+ elseif cur.type == 'setglobal'
+ or cur.type == 'getglobal' then
+ list[i] = cur
+ break
+ elseif cur.type == 'select'
+ or cur.type == 'table' then
+ list[i] = cur
+ break
+ elseif cur.type == 'string' then
+ list[i] = cur
+ break
+ elseif cur.type == 'doc.class.name'
+ or cur.type == 'doc.type.name'
+ or cur.type == 'doc.class' then
+ list[i] = cur
+ break
+ elseif cur.type == 'function'
+ or cur.type == 'main' then
+ break
+ else
+ return nil
+ end
+ end
+ return convertSimpleList(list)
+end
+
+function m.getSimple(obj, max)
+ local simpleList
+ if obj.type == 'getfield'
+ or obj.type == 'setfield'
+ or obj.type == 'getmethod'
+ or obj.type == 'setmethod'
+ or obj.type == 'getindex'
+ or obj.type == 'setindex'
+ or obj.type == 'local'
+ or obj.type == 'getlocal'
+ or obj.type == 'setlocal'
+ or obj.type == 'setglobal'
+ or obj.type == 'getglobal'
+ or obj.type == 'tablefield'
+ or obj.type == 'tableindex'
+ or obj.type == 'select'
+ or obj.type == 'table'
+ or obj.type == 'string'
+ or obj.type == 'doc.class.name'
+ or obj.type == 'doc.class'
+ or obj.type == 'doc.type.name' then
+ simpleList = buildSimpleList(obj, max)
+ elseif obj.type == 'field'
+ or obj.type == 'method' then
+ simpleList = buildSimpleList(obj.parent, max)
+ end
+ return simpleList
+end
+
+function m.status(parentStatus, interface)
+ local status = {
+ cache = parentStatus and parentStatus.cache or {
+ count = 0,
+ },
+ depth = parentStatus and (parentStatus.depth + 1) or 1,
+ interface = parentStatus and parentStatus.interface or {},
+ locks = parentStatus and parentStatus.locks or {},
+ deep = parentStatus and parentStatus.deep,
+ results = {},
+ }
+ status.lock = status.locks[status.depth] or {}
+ status.locks[status.depth] = status.lock
+ if interface then
+ for k, v in pairs(interface) do
+ status.interface[k] = v
+ end
+ end
+ local searchDepth = status.interface.getSearchDepth and status.interface.getSearchDepth() or 0
+ if status.depth >= searchDepth then
+ status.deep = false
+ end
+ return status
+end
+
+function m.copyStatusResults(a, b)
+ local ra = a.results
+ local rb = b.results
+ for i = 1, #rb do
+ ra[#ra+1] = rb[i]
+ end
+end
+
+function m.isGlobal(source)
+ if source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ if source.node and source.node.tag == '_ENV' then
+ return true
+ end
+ end
+ if source.type == 'field' then
+ source = source.parent
+ end
+ if source.type == 'getfield'
+ or source.type == 'setfield' then
+ local node = source.node
+ if node and node.special == '_G' then
+ return true
+ end
+ end
+ return false
+end
+
+function m.isDoc(source)
+ return source.type:sub(1, 4) == 'doc.'
+end
+
+--- 根据函数的调用参数,获取:调用,参数索引
+function m.getCallAndArgIndex(callarg)
+ local callargs = callarg.parent
+ if not callargs or callargs.type ~= 'callargs' then
+ return nil
+ end
+ local index
+ for i = 1, #callargs do
+ if callargs[i] == callarg then
+ index = i
+ break
+ end
+ end
+ local call = callargs.parent
+ return call, index
+end
+
+--- 根据函数调用的返回值,获取:调用的函数,参数列表,自己是第几个返回值
+function m.getCallValue(source)
+ local value = m.getObjectValue(source) or source
+ if not value then
+ return
+ end
+ local call, index
+ if value.type == 'call' then
+ call = value
+ index = 1
+ elseif value.type == 'select' then
+ call = value.vararg
+ index = value.index
+ if call.type ~= 'call' then
+ return
+ end
+ else
+ return
+ end
+ return call.node, call.args, index
+end
+
+function m.getNextRef(ref)
+ local nextRef = ref.next
+ if nextRef then
+ if nextRef.type == 'setfield'
+ or nextRef.type == 'getfield'
+ or nextRef.type == 'setmethod'
+ or nextRef.type == 'getmethod'
+ or nextRef.type == 'setindex'
+ or nextRef.type == 'getindex' then
+ return nextRef
+ end
+ end
+ -- 穿透 rawget 与 rawset
+ local call, index = m.getCallAndArgIndex(ref)
+ if call then
+ if call.node.special == 'rawset' and index == 1 then
+ return call
+ end
+ if call.node.special == 'rawget' and index == 1 then
+ return call
+ end
+ end
+ -- doc.type.array
+ if ref.type == 'doc.type' then
+ local arrays = {}
+ for _, typeUnit in ipairs(ref.types) do
+ if typeUnit.type == 'doc.type.array' then
+ arrays[#arrays+1] = typeUnit.node
+ end
+ end
+ -- 返回一个 dummy
+ -- TODO 用弱表维护唯一性?
+ return {
+ type = 'doc.type',
+ start = ref.start,
+ finish = ref.finish,
+ types = arrays,
+ parent = ref.parent,
+ array = true,
+ enums = {},
+ resumes = {},
+ }
+ end
+
+ return nil
+end
+
+function m.checkSameSimpleInValueOfTable(status, value, start, queue)
+ if value.type ~= 'table' then
+ return
+ end
+ for i = 1, #value do
+ local field = value[i]
+ queue[#queue+1] = {
+ obj = field,
+ start = start + 1,
+ }
+ end
+end
+
+function m.searchFields(status, obj, key)
+ local simple = m.getSimple(obj)
+ if not simple then
+ return
+ end
+ simple[#simple+1] = key and ('s|' .. key) or '*'
+ m.searchSameFields(status, simple, 'field')
+ m.cleanResults(status.results)
+end
+
+function m.getObjectValue(obj)
+ while obj.type == 'paren' do
+ obj = obj.exp
+ if not obj then
+ return nil
+ end
+ end
+ if obj.type == 'boolean'
+ or obj.type == 'number'
+ or obj.type == 'integer'
+ or obj.type == 'string' then
+ return obj
+ end
+ if obj.value then
+ return obj.value
+ end
+ if obj.type == 'field'
+ or obj.type == 'method' then
+ return obj.parent.value
+ end
+ if obj.type == 'call' then
+ if obj.node.special == 'rawset' then
+ return obj.args[3]
+ end
+ end
+ if obj.type == 'select' then
+ return obj
+ end
+ return nil
+end
+
+function m.checkSameSimpleInValueInMetaTable(status, mt, start, queue)
+ local newStatus = m.status(status)
+ m.searchFields(newStatus, mt, '__index')
+ local refsStatus = m.status(status)
+ for i = 1, #newStatus.results do
+ local indexValue = m.getObjectValue(newStatus.results[i])
+ if indexValue then
+ m.searchRefs(refsStatus, indexValue, 'ref')
+ end
+ end
+ for i = 1, #refsStatus.results do
+ local obj = refsStatus.results[i]
+ queue[#queue+1] = {
+ obj = obj,
+ start = start,
+ force = true,
+ }
+ end
+end
+function m.checkSameSimpleInValueOfSetMetaTable(status, func, start, queue)
+ if not func or func.special ~= 'setmetatable' then
+ return
+ end
+ local call = func.parent
+ local args = call.args
+ local obj = args[1]
+ local mt = args[2]
+ if obj then
+ queue[#queue+1] = {
+ obj = obj,
+ start = start,
+ force = true,
+ }
+ end
+ if mt then
+ m.checkSameSimpleInValueInMetaTable(status, mt, start, queue)
+ end
+end
+
+function m.checkSameSimpleInValueOfCallMetaTable(status, call, start, queue)
+ if call.type == 'call' then
+ m.checkSameSimpleInValueOfSetMetaTable(status, call.node, start, queue)
+ end
+end
+
+function m.checkSameSimpleInSpecialBranch(status, obj, start, queue)
+ if not status.interface.index then
+ return
+ end
+ local results = status.interface.index(obj)
+ if not results then
+ return
+ end
+ for _, res in ipairs(results) do
+ queue[#queue+1] = {
+ obj = res,
+ start = start + 1,
+ }
+ end
+end
+
+function m.checkSameSimpleByDocType(status, doc)
+ if status.cache.searchingBindedDoc then
+ return
+ end
+ if doc.type ~= 'doc.type' then
+ return
+ end
+ local results = {}
+ for _, piece in ipairs(doc.types) do
+ local pieceResult = stepRefOfDocType(status, piece, 'def')
+ for _, res in ipairs(pieceResult) do
+ results[#results+1] = res
+ end
+ end
+ return results
+end
+
+function m.checkSameSimpleByBindDocs(status, obj, start, queue, mode)
+ if not obj.bindDocs then
+ return
+ end
+ if status.cache.searchingBindedDoc then
+ return
+ end
+ local skipInfer = false
+ local results = {}
+ for _, doc in ipairs(obj.bindDocs) do
+ if doc.type == 'doc.class' then
+ results[#results+1] = doc
+ elseif doc.type == 'doc.type' then
+ results[#results+1] = doc
+ elseif doc.type == 'doc.param' then
+ -- function (x) 的情况
+ if obj.type == 'local'
+ and m.getName(obj) == doc.param[1] then
+ if obj.parent.type == 'funcargs'
+ or obj.parent.type == 'in'
+ or obj.parent.type == 'loop' then
+ results[#results+1] = doc.extends
+ end
+ end
+ elseif doc.type == 'doc.field' then
+ results[#results+1] = doc
+ end
+ end
+ for _, res in ipairs(results) do
+ if res.type == 'doc.class'
+ or res.type == 'doc.type' then
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ skipInfer = true
+ end
+ if res.type == 'doc.type.function' then
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ elseif res.type == 'doc.field' then
+ queue[#queue+1] = {
+ obj = res,
+ start = start + 1,
+ }
+ end
+ end
+ return skipInfer
+end
+
+function m.checkSameSimpleOfRefByDocSource(status, obj, start, queue, mode)
+ if status.cache.searchingBindedDoc then
+ return
+ end
+ if not obj.bindSources then
+ return
+ end
+ status.cache.searchingBindedDoc = true
+ local mark = {}
+ local newStatus = m.status(status)
+ for _, ref in ipairs(obj.bindSources) do
+ if not mark[ref] then
+ mark[ref] = true
+ m.searchRefs(newStatus, ref, mode)
+ end
+ end
+ status.cache.searchingBindedDoc = nil
+ for _, res in ipairs(newStatus.results) do
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ end
+end
+
+function m.checkSameSimpleByDoc(status, obj, start, queue, mode)
+ if obj.type == 'doc.class.name'
+ or obj.type == 'doc.type.name' then
+ obj = m.getDocState(obj)
+ end
+ if obj.type == 'doc.class' then
+ local classStart
+ for _, doc in ipairs(obj.bindGroup) do
+ if doc == obj then
+ classStart = true
+ elseif doc.type == 'doc.class' then
+ classStart = false
+ end
+ if classStart and doc.type == 'doc.field' then
+ queue[#queue+1] = {
+ obj = doc,
+ start = start + 1,
+ }
+ end
+ end
+ m.checkSameSimpleOfRefByDocSource(status, obj, start, queue, mode)
+ if mode == 'ref' then
+ local pieceResult = stepRefOfDocType(status, obj.class, 'ref')
+ for _, res in ipairs(pieceResult) do
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ end
+ end
+ return true
+ elseif obj.type == 'doc.type' then
+ for _, piece in ipairs(obj.types) do
+ local pieceResult = stepRefOfDocType(status, piece, 'def')
+ for _, res in ipairs(pieceResult) do
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ end
+ end
+ if mode == 'ref' then
+ m.checkSameSimpleOfRefByDocSource(status, obj, start, queue, mode)
+ end
+ return true
+ elseif obj.type == 'doc.field' then
+ if mode ~= 'field' then
+ return m.checkSameSimpleByDoc(status, obj.extends, start, queue, mode)
+ end
+ end
+end
+
+function m.checkSameSimpleInArg1OfSetMetaTable(status, obj, start, queue)
+ local args = obj.parent
+ if not args or args.type ~= 'callargs' then
+ return
+ end
+ if args[1] ~= obj then
+ return
+ end
+ local mt = args[2]
+ if mt then
+ if m.checkValueMark(status, obj, mt) then
+ return
+ end
+ m.checkSameSimpleInValueInMetaTable(status, mt, start, queue)
+ end
+end
+
+function m.searchSameMethodCrossSelf(ref, mark)
+ local selfNode
+ if ref.tag == 'self' then
+ selfNode = ref
+ else
+ if ref.type == 'getlocal'
+ or ref.type == 'setlocal' then
+ local node = ref.node
+ if node.tag == 'self' then
+ selfNode = node
+ end
+ end
+ end
+ if selfNode then
+ if mark[selfNode] then
+ return nil
+ end
+ mark[selfNode] = true
+ return selfNode.method.node
+ end
+end
+
+function m.searchSameMethod(ref, mark)
+ if mark['method'] then
+ return nil
+ end
+ local nxt = ref.next
+ if not nxt then
+ return nil
+ end
+ if nxt.type == 'setmethod' then
+ mark['method'] = true
+ return ref
+ end
+ return nil
+end
+
+function m.searchSameFieldsCrossMethod(status, ref, start, queue)
+ local mark = status.cache.crossMethodMark
+ if not mark then
+ mark = {}
+ status.cache.crossMethodMark = mark
+ end
+ local method = m.searchSameMethod(ref, mark)
+ or m.searchSameMethodCrossSelf(ref, mark)
+ if not method then
+ return
+ end
+ local methodStatus = m.status(status)
+ m.searchRefs(methodStatus, method, 'ref')
+ for _, md in ipairs(methodStatus.results) do
+ queue[#queue+1] = {
+ obj = md,
+ start = start,
+ force = true,
+ }
+ local nxt = md.next
+ if not nxt then
+ goto CONTINUE
+ end
+ if nxt.type == 'setmethod' then
+ local func = nxt.value
+ if not func then
+ goto CONTINUE
+ end
+ local selfNode = func.locals and func.locals[1]
+ if not selfNode or not selfNode.ref then
+ goto CONTINUE
+ end
+ if mark[selfNode] then
+ goto CONTINUE
+ end
+ mark[selfNode] = true
+ for _, selfRef in ipairs(selfNode.ref) do
+ queue[#queue+1] = {
+ obj = selfRef,
+ start = start,
+ force = true,
+ }
+ end
+ end
+ ::CONTINUE::
+ end
+end
+
+local function checkSameSimpleAndMergeFunctionReturnsByDoc(status, results, source, index, call)
+ if not source or source.type ~= 'function' then
+ return
+ end
+ if not source.bindDocs then
+ return
+ end
+ local returns = {}
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.return' then
+ for _, rtn in ipairs(doc.returns) do
+ returns[#returns+1] = rtn
+ end
+ end
+ end
+ local rtn = returns[index]
+ if not rtn then
+ return
+ end
+ local types = m.checkSameSimpleByDocType(status, rtn)
+ if not types then
+ return
+ end
+ for _, res in ipairs(types) do
+ results[#results+1] = res
+ end
+ return true
+end
+
+local function checkSameSimpleAndMergeDocTypeFunctionReturns(status, results, source, index)
+ if not source.bindDocs then
+ return
+ end
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.type' then
+ for _, typeUnit in ipairs(doc.types) do
+ if typeUnit.type == 'doc.type.function' then
+ local rtn = typeUnit.returns[index]
+ if rtn then
+ local types = m.checkSameSimpleByDocType(status, rtn)
+ if types then
+ for _, res in ipairs(types) do
+ results[#results+1] = res
+ end
+ return true
+ end
+ end
+ end
+ end
+ end
+ end
+end
+
+function m.checkSameSimpleInCallInSameFile(status, func, args, index)
+ local newStatus = m.status(status)
+ m.searchRefs(newStatus, func, 'def')
+ local results = {}
+ for _, def in ipairs(newStatus.results) do
+ local hasDocReturn = checkSameSimpleAndMergeDocTypeFunctionReturns(status, results, def, index)
+ or checkSameSimpleAndMergeFunctionReturnsByDoc(status, results, def, index)
+ if not hasDocReturn then
+ local value = m.getObjectValue(def) or def
+ if value.type == 'function' then
+ local returns = value.returns
+ if returns then
+ for _, ret in ipairs(returns) do
+ local exp = ret[index]
+ if exp then
+ results[#results+1] = exp
+ end
+ end
+ end
+ end
+ end
+ end
+ return results
+end
+
+function m.checkSameSimpleInCall(status, ref, start, queue, mode)
+ local func, args, index = m.getCallValue(ref)
+ if not func then
+ return
+ end
+ if m.checkCallMark(status, func.parent, true) then
+ return
+ end
+ status.cache.crossCallCount = status.cache.crossCallCount or 0
+ if status.cache.crossCallCount >= 5 then
+ return
+ end
+ status.cache.crossCallCount = status.cache.crossCallCount + 1
+ -- 检查赋值是 semetatable() 的情况
+ m.checkSameSimpleInValueOfSetMetaTable(status, func, start, queue)
+ -- 检查赋值是 func() 的情况
+ local objs = m.checkSameSimpleInCallInSameFile(status, func, args, index)
+ if status.interface.call then
+ local cobjs = status.interface.call(func, args, index)
+ if cobjs then
+ for _, obj in ipairs(cobjs) do
+ if not m.checkReturnMark(status, obj) then
+ objs[#objs+1] = obj
+ end
+ end
+ end
+ end
+ m.cleanResults(objs)
+ local newStatus = m.status(status)
+ for _, obj in ipairs(objs) do
+ m.searchRefs(newStatus, obj, mode)
+ queue[#queue+1] = {
+ obj = obj,
+ start = start,
+ force = true,
+ }
+ end
+ status.cache.crossCallCount = status.cache.crossCallCount - 1
+ for _, obj in ipairs(newStatus.results) do
+ queue[#queue+1] = {
+ obj = obj,
+ start = start,
+ force = true,
+ }
+ end
+end
+
+local function searchRawset(ref, results)
+ if m.getKeyName(ref) ~= 's|rawset' then
+ return
+ end
+ local call = ref.parent
+ if call.type ~= 'call' or call.node ~= ref then
+ return
+ end
+ if not call.args then
+ return
+ end
+ local arg1 = call.args[1]
+ if arg1.special ~= '_G' then
+ -- 不会吧不会吧,不会真的有人写成 `rawset(_G._G._G, 'xxx', value)` 吧
+ return
+ end
+ results[#results+1] = call
+end
+
+local function searchG(ref, results)
+ while ref and m.getKeyName(ref) == 's|_G' do
+ results[#results+1] = ref
+ ref = ref.next
+ end
+ if ref then
+ results[#results+1] = ref
+ searchRawset(ref, results)
+ end
+end
+
+local function searchEnvRef(ref, results)
+ if ref.type == 'setglobal'
+ or ref.type == 'getglobal' then
+ results[#results+1] = ref
+ searchG(ref, results)
+ elseif ref.type == 'getlocal' then
+ results[#results+1] = ref.next
+ searchG(ref.next, results)
+ end
+end
+
+function m.findGlobals(ast)
+ local root = m.getRoot(ast)
+ local results = {}
+ local env = m.getENV(root)
+ if env.ref then
+ for _, ref in ipairs(env.ref) do
+ searchEnvRef(ref, results)
+ end
+ end
+ return results
+end
+
+function m.findGlobalsOfName(ast, name)
+ local root = m.getRoot(ast)
+ local results = {}
+ local globals = m.findGlobals(root)
+ for _, global in ipairs(globals) do
+ if m.getKeyName(global) == name then
+ results[#results+1] = global
+ end
+ end
+ return results
+end
+
+function m.checkSameSimpleInGlobal(status, name, source, start, queue)
+ if not name then
+ return
+ end
+ local objs
+ if status.interface.global then
+ objs = status.interface.global(name)
+ else
+ objs = m.findGlobalsOfName(source, name)
+ end
+ if objs then
+ for _, obj in ipairs(objs) do
+ queue[#queue+1] = {
+ obj = obj,
+ start = start,
+ force = true,
+ }
+ end
+ end
+end
+
+function m.checkValueMark(status, a, b)
+ if not status.cache.valueMark then
+ status.cache.valueMark = {}
+ end
+ if status.cache.valueMark[a]
+ or status.cache.valueMark[b] then
+ return true
+ end
+ status.cache.valueMark[a] = true
+ status.cache.valueMark[b] = true
+ return false
+end
+
+function m.checkCallMark(status, a, mark)
+ if not status.cache.callMark then
+ status.cache.callMark = {}
+ end
+ if mark then
+ status.cache.callMark[a] = mark
+ else
+ return status.cache.callMark[a]
+ end
+ return false
+end
+
+function m.checkReturnMark(status, a, mark)
+ if not status.cache.returnMark then
+ status.cache.returnMark = {}
+ end
+ if mark then
+ status.cache.returnMark[a] = mark
+ else
+ return status.cache.returnMark[a]
+ end
+ return false
+end
+
+function m.searchSameFieldsInValue(status, ref, start, queue, mode)
+ local value = m.getObjectValue(ref)
+ if not value then
+ return
+ end
+ if m.checkValueMark(status, ref, value) then
+ return
+ end
+ local newStatus = m.status(status)
+ m.searchRefs(newStatus, value, mode)
+ for _, res in ipairs(newStatus.results) do
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ end
+ queue[#queue+1] = {
+ obj = value,
+ start = start,
+ force = true,
+ }
+ -- 检查形如 a = f() 的分支情况
+ m.checkSameSimpleInCall(status, value, start, queue, mode)
+end
+
+function m.checkSameSimpleAsTableField(status, ref, start, queue)
+ if not status.deep then
+ --return
+ end
+ local parent = ref.parent
+ if not parent or parent.type ~= 'tablefield' then
+ return
+ end
+ if m.checkValueMark(status, parent, ref) then
+ return
+ end
+ local newStatus = m.status(status)
+ m.searchRefs(newStatus, parent.field, 'ref')
+ for _, res in ipairs(newStatus.results) do
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ end
+end
+
+function m.checkSearchLevel(status)
+ status.cache.back = status.cache.back or 0
+ if status.cache.back >= (status.interface.searchLevel or 0) then
+ -- TODO 限制向前搜索的次数
+ --return true
+ end
+ status.cache.back = status.cache.back + 1
+ return false
+end
+
+function m.checkSameSimpleAsReturn(status, ref, start, queue)
+ if not status.deep then
+ return
+ end
+ if not ref.parent or ref.parent.type ~= 'return' then
+ return
+ end
+ if ref.parent.parent.type ~= 'main' then
+ return
+ end
+ if m.checkSearchLevel(status) then
+ return
+ end
+ local newStatus = m.status(status)
+ m.searchRefsAsFunctionReturn(newStatus, ref, 'ref')
+ for _, res in ipairs(newStatus.results) do
+ if not m.checkCallMark(status, res) then
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ end
+ end
+end
+
+function m.checkSameSimpleAsSetValue(status, ref, start, queue)
+ if ref.type == 'select' then
+ return
+ end
+ local parent = ref.parent
+ if not parent then
+ return
+ end
+ if m.getObjectValue(parent) ~= ref then
+ return
+ end
+ if m.checkValueMark(status, ref, parent) then
+ return
+ end
+ if m.checkSearchLevel(status) then
+ return
+ end
+ local obj
+ if parent.type == 'local'
+ or parent.type == 'setglobal'
+ or parent.type == 'setlocal' then
+ obj = parent
+ elseif parent.type == 'setfield' then
+ obj = parent.field
+ elseif parent.type == 'setmethod' then
+ obj = parent.method
+ end
+ if not obj then
+ return
+ end
+ local newStatus = m.status(status)
+ m.searchRefs(newStatus, obj, 'ref')
+ for _, res in ipairs(newStatus.results) do
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ end
+end
+
+function m.checkSameSimpleInString(status, ref, start, queue, mode)
+ -- 特殊处理 ('xxx').xxx 的形式
+ if ref.type ~= 'string' then
+ return
+ end
+ if not status.interface.docType then
+ return
+ end
+ if status.cache.searchingBindedDoc then
+ return
+ end
+ local newStatus = m.status(status)
+ local docs = status.interface.docType('string*')
+ local mark = {}
+ for i = 1, #docs do
+ local doc = docs[i]
+ m.searchFields(newStatus, doc)
+ end
+ for _, res in ipairs(newStatus.results) do
+ if mark[res] then
+ goto CONTINUE
+ end
+ mark[res] = true
+ queue[#queue+1] = {
+ obj = res,
+ start = start + 1,
+ }
+ ::CONTINUE::
+ end
+ return true
+end
+
+function m.pushResult(status, mode, ref, simple)
+ local results = status.results
+ if mode == 'def' then
+ if ref.type == 'setglobal'
+ or ref.type == 'setlocal'
+ or ref.type == 'local' then
+ results[#results+1] = ref
+ elseif ref.type == 'setfield'
+ or ref.type == 'tablefield' then
+ results[#results+1] = ref
+ elseif ref.type == 'setmethod' then
+ results[#results+1] = ref
+ elseif ref.type == 'setindex'
+ or ref.type == 'tableindex' then
+ results[#results+1] = ref
+ elseif ref.type == 'call' then
+ if ref.node.special == 'rawset' then
+ results[#results+1] = ref
+ end
+ elseif ref.type == 'function' then
+ results[#results+1] = ref
+ elseif ref.type == 'table' then
+ results[#results+1] = ref
+ elseif ref.type == 'doc.type.function'
+ or ref.type == 'doc.class.name'
+ or ref.type == 'doc.field' then
+ results[#results+1] = ref
+ end
+ if ref.parent and ref.parent.type == 'return' then
+ if m.getParentFunction(ref) ~= m.getParentFunction(simple.node) then
+ results[#results+1] = ref
+ end
+ end
+ elseif mode == 'ref' then
+ if ref.type == 'setfield'
+ or ref.type == 'getfield'
+ or ref.type == 'tablefield' then
+ results[#results+1] = ref
+ elseif ref.type == 'setmethod'
+ or ref.type == 'getmethod' then
+ results[#results+1] = ref
+ elseif ref.type == 'setindex'
+ or ref.type == 'getindex'
+ or ref.type == 'tableindex' then
+ results[#results+1] = ref
+ elseif ref.type == 'setglobal'
+ or ref.type == 'getglobal'
+ or ref.type == 'local'
+ or ref.type == 'setlocal'
+ or ref.type == 'getlocal' then
+ results[#results+1] = ref
+ elseif ref.type == 'function' then
+ results[#results+1] = ref
+ elseif ref.type == 'table' then
+ results[#results+1] = ref
+ elseif ref.type == 'call' then
+ if ref.node.special == 'rawset'
+ or ref.node.special == 'rawget' then
+ results[#results+1] = ref
+ end
+ elseif ref.type == 'doc.type.function'
+ or ref.type == 'doc.class.name'
+ or ref.type == 'doc.field' then
+ results[#results+1] = ref
+ end
+ if ref.parent and ref.parent.type == 'return' then
+ results[#results+1] = ref
+ end
+ elseif mode == 'field' then
+ if ref.type == 'setfield'
+ or ref.type == 'getfield'
+ or ref.type == 'tablefield' then
+ results[#results+1] = ref
+ elseif ref.type == 'setmethod'
+ or ref.type == 'getmethod' then
+ results[#results+1] = ref
+ elseif ref.type == 'setindex'
+ or ref.type == 'getindex'
+ or ref.type == 'tableindex' then
+ results[#results+1] = ref
+ elseif ref.type == 'setglobal'
+ or ref.type == 'getglobal' then
+ results[#results+1] = ref
+ elseif ref.type == 'function' then
+ results[#results+1] = ref
+ elseif ref.type == 'table' then
+ results[#results+1] = ref
+ elseif ref.type == 'call' then
+ if ref.node.special == 'rawset'
+ or ref.node.special == 'rawget' then
+ results[#results+1] = ref
+ end
+ elseif ref.type == 'doc.type.function'
+ or ref.type == 'doc.class.name'
+ or ref.type == 'doc.field' then
+ results[#results+1] = ref
+ end
+ end
+end
+
+function m.checkSameSimpleName(ref, sm)
+ if sm == '*' then
+ return true
+ end
+ if m.getSimpleName(ref) == sm then
+ return true
+ end
+ if ref.type == 'doc.type'
+ and ref.array == true then
+ return true
+ end
+ return false
+end
+
+function m.checkSameSimple(status, simple, data, mode, queue)
+ local ref = data.obj
+ local start = data.start
+ local force = data.force
+ if start > #simple then
+ return
+ end
+ for i = start, #simple do
+ local sm = simple[i]
+ if not force and not m.checkSameSimpleName(ref, sm) then
+ return
+ end
+ force = false
+ local cmode = mode
+ if i < #simple then
+ cmode = 'ref'
+ end
+ -- 检查 doc
+ local skipInfer = m.checkSameSimpleByBindDocs(status, ref, i, queue, cmode)
+ or m.checkSameSimpleByDoc(status, ref, i, queue, cmode)
+ if not skipInfer then
+ -- 穿透 self:func 与 mt:func
+ m.searchSameFieldsCrossMethod(status, ref, i, queue)
+ -- 穿透赋值
+ m.searchSameFieldsInValue(status, ref, i, queue, cmode)
+ -- 检查自己是字面量表的情况
+ m.checkSameSimpleInValueOfTable(status, ref, i, queue)
+ -- 检查自己作为 setmetatable 第一个参数的情况
+ m.checkSameSimpleInArg1OfSetMetaTable(status, ref, i, queue)
+ -- 检查自己作为 setmetatable 调用的情况
+ m.checkSameSimpleInValueOfCallMetaTable(status, ref, i, queue)
+ -- 检查自己是特殊变量的分支的情况
+ m.checkSameSimpleInSpecialBranch(status, ref, i, queue)
+ -- 检查自己是字面量字符串的分支情况
+ m.checkSameSimpleInString(status, ref, i, queue, cmode)
+ if cmode == 'ref' then
+ -- 检查形如 { a = f } 的情况
+ m.checkSameSimpleAsTableField(status, ref, i, queue)
+ -- 检查形如 return m 的情况
+ m.checkSameSimpleAsReturn(status, ref, i, queue)
+ -- 检查形如 a = f 的情况
+ m.checkSameSimpleAsSetValue(status, ref, i, queue)
+ end
+ end
+ if i == #simple then
+ break
+ end
+ ref = m.getNextRef(ref)
+ if not ref then
+ return
+ end
+ end
+ m.pushResult(status, mode, ref, simple)
+ local value = m.getObjectValue(ref)
+ if value then
+ m.pushResult(status, mode, value, simple)
+ end
+end
+
+function m.searchSameFields(status, simple, mode)
+ local queue = {}
+ if simple.mode == 'global' then
+ -- 全局变量开头
+ m.checkSameSimpleInGlobal(status, simple[1], simple.node, 1, queue)
+ elseif simple.mode == 'local' then
+ -- 局部变量开头
+ queue[1] = {
+ obj = simple.node,
+ start = 1,
+ }
+ local refs = simple.node.ref
+ if refs then
+ for i = 1, #refs do
+ queue[#queue+1] = {
+ obj = refs[i],
+ start = 1,
+ }
+ end
+ end
+ else
+ queue[1] = {
+ obj = simple.node,
+ start = 1,
+ }
+ end
+ local max = 0
+ for i = 1, 1e6 do
+ local data = queue[i]
+ if not data then
+ return
+ end
+ if not status.lock[data.obj] then
+ status.lock[data.obj] = true
+ max = max + 1
+ status.cache.count = status.cache.count + 1
+ m.checkSameSimple(status, simple, data, mode, queue)
+ if max >= 10000 then
+ logWarn('Queue too large!')
+ break
+ end
+ end
+ end
+end
+
+function m.getCallerInSameFile(status, func)
+ -- 搜索所有所在函数的调用者
+ local funcRefs = m.status(status)
+ m.searchRefOfValue(funcRefs, func)
+
+ local calls = {}
+ if #funcRefs.results == 0 then
+ return calls
+ end
+ for _, res in ipairs(funcRefs.results) do
+ local call = res.parent
+ if call.type == 'call' then
+ calls[#calls+1] = call
+ end
+ end
+ return calls
+end
+
+function m.getCallerCrossFiles(status, main)
+ if status.interface.link then
+ return status.interface.link(main.uri)
+ end
+ return {}
+end
+
+function m.searchRefsAsFunctionReturn(status, obj, mode)
+ if mode == 'def' then
+ return
+ end
+ if m.checkReturnMark(status, obj, true) then
+ return
+ end
+ status.results[#status.results+1] = obj
+ -- 搜索所在函数
+ local currentFunc = m.getParentFunction(obj)
+ local rtn = obj.parent
+ if rtn.type ~= 'return' then
+ return
+ end
+ -- 看看他是第几个返回值
+ local index
+ for i = 1, #rtn do
+ if obj == rtn[i] then
+ index = i
+ break
+ end
+ end
+ if not index then
+ return
+ end
+ local calls
+ if currentFunc.type == 'main' then
+ calls = m.getCallerCrossFiles(status, currentFunc)
+ else
+ calls = m.getCallerInSameFile(status, currentFunc)
+ end
+ -- 搜索调用者的返回值
+ if #calls == 0 then
+ return
+ end
+ local selects = {}
+ for i = 1, #calls do
+ local parent = calls[i].parent
+ if parent.type == 'select' and parent.index == index then
+ selects[#selects+1] = parent.parent
+ end
+ local extParent = calls[i].extParent
+ if extParent then
+ for j = 1, #extParent do
+ local ext = extParent[j]
+ if ext.type == 'select' and ext.index == index then
+ selects[#selects+1] = ext.parent
+ end
+ end
+ end
+ end
+ -- 搜索调用者的引用
+ for i = 1, #selects do
+ m.searchRefs(status, selects[i], 'ref')
+ end
+end
+
+function m.searchRefsAsFunctionSet(status, obj, mode)
+ local parent = obj.parent
+ if not parent then
+ return
+ end
+ if parent.type == 'local'
+ or parent.type == 'setlocal'
+ or parent.type == 'setglobal'
+ or parent.type == 'setfield'
+ or parent.type == 'setmethod'
+ or parent.type == 'tablefield' then
+ m.searchRefs(status, parent, mode)
+ elseif parent.type == 'setindex'
+ or parent.type == 'tableindex' then
+ if parent.index == obj then
+ m.searchRefs(status, parent, mode)
+ end
+ end
+end
+
+function m.searchRefsAsFunction(status, obj, mode)
+ if obj.type ~= 'function'
+ and obj.type ~= 'table' then
+ return
+ end
+ m.searchRefsAsFunctionSet(status, obj, mode)
+ -- 检查自己作为返回函数时的引用
+ m.searchRefsAsFunctionReturn(status, obj, mode)
+end
+
+function m.cleanResults(results)
+ local mark = {}
+ for i = #results, 1, -1 do
+ local res = results[i]
+ if res.tag == 'self'
+ or mark[res] then
+ results[i] = results[#results]
+ results[#results] = nil
+ else
+ mark[res] = true
+ end
+ end
+end
+
+--function m.getRefCache(status, obj, mode)
+-- local cache = status.interface.cache and status.interface.cache()
+-- if not cache then
+-- return
+-- end
+-- if m.isGlobal(obj) then
+-- obj = m.getKeyName(obj)
+-- end
+-- if not cache[mode] then
+-- cache[mode] = {}
+-- end
+-- local sourceCache = cache[mode][obj]
+-- if sourceCache then
+-- return sourceCache
+-- end
+-- sourceCache = {}
+-- cache[mode][obj] = sourceCache
+-- return nil, function (results)
+-- for i = 1, #results do
+-- sourceCache[i] = results[i]
+-- end
+-- end
+--end
+
+function m.getRefCache(status, obj, mode)
+ local cache, globalCache
+ if status.depth == 1
+ and status.deep then
+ globalCache = status.interface.cache and status.interface.cache() or {}
+ end
+ cache = status.cache.refCache or {}
+ status.cache.refCache = cache
+ if m.isGlobal(obj) then
+ obj = m.getKeyName(obj)
+ end
+ if not cache[mode] then
+ cache[mode] = {}
+ end
+ if globalCache and not globalCache[mode] then
+ globalCache[mode] = {}
+ end
+ local sourceCache = globalCache and globalCache[mode][obj] or cache[mode][obj]
+ if sourceCache then
+ return sourceCache
+ end
+ sourceCache = {}
+ cache[mode][obj] = sourceCache
+ if globalCache then
+ globalCache[mode][obj] = sourceCache
+ end
+ return nil, function (results)
+ for i = 1, #results do
+ sourceCache[i] = results[i]
+ end
+ end
+end
+
+function m.searchRefs(status, obj, mode)
+ local cache, makeCache = m.getRefCache(status, obj, mode)
+ if cache then
+ for i = 1, #cache do
+ status.results[#status.results+1] = cache[i]
+ end
+ return
+ end
+
+ -- 检查单步引用
+ local res = m.getStepRef(status, obj, mode)
+ if res then
+ for i = 1, #res do
+ status.results[#status.results+1] = res[i]
+ end
+ end
+ -- 检查simple
+ if status.depth <= 100 then
+ local simple = m.getSimple(obj)
+ if simple then
+ m.searchSameFields(status, simple, mode)
+ end
+ else
+ if m.debugMode then
+ error('status.depth overflow')
+ elseif DEVELOP then
+ --log.warn(debug.traceback('status.depth overflow'))
+ logWarn('status.depth overflow')
+ end
+ end
+
+ m.cleanResults(status.results)
+
+ if makeCache then
+ makeCache(status.results)
+ end
+end
+
+function m.searchRefOfValue(status, obj)
+ local var = obj.parent
+ if var.type == 'local'
+ or var.type == 'set' then
+ return m.searchRefs(status, var, 'ref')
+ end
+end
+
+function m.allocInfer(o)
+ if type(o.type) == 'table' then
+ local infers = {}
+ for i = 1, #o.type do
+ infers[i] = {
+ type = o.type[i],
+ value = o.value,
+ source = o.source,
+ }
+ end
+ return infers
+ else
+ return {
+ [1] = o,
+ }
+ end
+end
+
+function m.mergeTypes(types)
+ local results = {}
+ local mark = {}
+ local hasAny
+ -- 这里把 any 去掉
+ for i = 1, #types do
+ local tp = types[i]
+ if tp == 'any' then
+ hasAny = true
+ end
+ if not mark[tp] and tp ~= 'any' then
+ mark[tp] = true
+ results[#results+1] = tp
+ end
+ end
+ if #results == 0 then
+ return 'any'
+ end
+ -- 只有显性的 nil 与 any 时,取 any
+ if #results == 1 then
+ if results[1] == 'nil' and hasAny then
+ return 'any'
+ else
+ return results[1]
+ end
+ end
+ -- 同时包含 number 与 integer 时,去掉 integer
+ if mark['number'] and mark['integer'] then
+ for i = 1, #results do
+ if results[i] == 'integer' then
+ tableRemove(results, i)
+ break
+ end
+ end
+ end
+ tableSort(results, function (a, b)
+ local sa = TypeSort[a] or 100
+ local sb = TypeSort[b] or 100
+ return sa < sb
+ end)
+ return tableConcat(results, '|')
+end
+
+function m.viewInferType(infers)
+ if not infers then
+ return 'any'
+ end
+ local mark = {}
+ local types = {}
+ local hasDoc
+ for i = 1, #infers do
+ local infer = infers[i]
+ local src = infer.source
+ if src.type == 'doc.class'
+ or src.type == 'doc.class.name'
+ or src.type == 'doc.type.name'
+ or src.type == 'doc.type.array'
+ or src.type == 'doc.type.generic' then
+ if infer.type ~= 'any' then
+ hasDoc = true
+ break
+ end
+ end
+ end
+ if hasDoc then
+ for i = 1, #infers do
+ local infer = infers[i]
+ local src = infer.source
+ if src.type == 'doc.class'
+ or src.type == 'doc.class.name'
+ or src.type == 'doc.type.name'
+ or src.type == 'doc.type.array'
+ or src.type == 'doc.type.generic'
+ or src.type == 'doc.type.enum'
+ or src.type == 'doc.resume' then
+ local tp = infer.type or 'any'
+ if not mark[tp] then
+ types[#types+1] = tp
+ end
+ mark[tp] = true
+ end
+ end
+ else
+ for i = 1, #infers do
+ local tp = infers[i].type or 'any'
+ if not mark[tp] then
+ types[#types+1] = tp
+ end
+ mark[tp] = true
+ end
+ end
+ return m.mergeTypes(types)
+end
+
+function m.checkTrue(status, source)
+ local newStatus = m.status(status)
+ m.searchInfer(newStatus, source)
+ -- 当前认为的结果
+ local current
+ for _, infer in ipairs(newStatus.results) do
+ -- 新的结果
+ local new
+ if infer.type == 'nil' then
+ new = false
+ elseif infer.type == 'boolean' then
+ if infer.value == true then
+ new = true
+ elseif infer.value == false then
+ new = false
+ end
+ end
+ if new ~= nil then
+ if current == nil then
+ current = new
+ else
+ -- 如果2个结果完全相反,则返回 nil 表示不确定
+ if new ~= current then
+ return nil
+ end
+ end
+ end
+ end
+ return current
+end
+
+--- 获取特定类型的字面量值
+function m.getInferLiteral(status, source, type)
+ local newStatus = m.status(status)
+ m.searchInfer(newStatus, source)
+ for _, infer in ipairs(newStatus.results) do
+ if infer.value ~= nil then
+ if type == nil or infer.type == type then
+ return infer.value
+ end
+ end
+ end
+ return nil
+end
+
+--- 是否包含某种类型
+function m.hasType(status, source, type)
+ m.searchInfer(status, source)
+ for _, infer in ipairs(status.results) do
+ if infer.type == type then
+ return true
+ end
+ end
+ return false
+end
+
+function m.isSameValue(status, a, b)
+ local statusA = m.status(status)
+ m.searchInfer(statusA, a)
+ local statusB = m.status(status)
+ m.searchInfer(statusB, b)
+ local infers = {}
+ for _, infer in ipairs(statusA.results) do
+ local literal = infer.value
+ if literal then
+ infers[literal] = false
+ end
+ end
+ for _, infer in ipairs(statusB.results) do
+ local literal = infer.value
+ if literal then
+ if infers[literal] == nil then
+ return false
+ end
+ infers[literal] = true
+ end
+ end
+ for k, v in pairs(infers) do
+ if v == false then
+ return false
+ end
+ end
+ return true
+end
+
+function m.inferCheckLiteralTableWithDocVararg(status, source)
+ if #source ~= 1 then
+ return
+ end
+ local vararg = source[1]
+ if vararg.type ~= 'varargs' then
+ return
+ end
+ local results = m.getVarargDocType(status, source)
+ status.results[#status.results+1] = {
+ type = m.viewInferType(results) .. '[]',
+ source = source,
+ }
+ return true
+end
+
+function m.inferCheckLiteral(status, source)
+ if source.type == 'string' then
+ status.results = m.allocInfer {
+ type = 'string',
+ value = source[1],
+ source = source,
+ }
+ return true
+ elseif source.type == 'nil' then
+ status.results = m.allocInfer {
+ type = 'nil',
+ value = NIL,
+ source = source,
+ }
+ return true
+ elseif source.type == 'boolean' then
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = source[1],
+ source = source,
+ }
+ return true
+ elseif source.type == 'number' then
+ if mathType(source[1]) == 'integer' then
+ status.results = m.allocInfer {
+ type = 'integer',
+ value = source[1],
+ source = source,
+ }
+ return true
+ else
+ status.results = m.allocInfer {
+ type = 'number',
+ value = source[1],
+ source = source,
+ }
+ return true
+ end
+ elseif source.type == 'integer' then
+ status.results = m.allocInfer {
+ type = 'integer',
+ source = source,
+ }
+ return true
+ elseif source.type == 'table' then
+ if m.inferCheckLiteralTableWithDocVararg(status, source) then
+ return true
+ end
+ status.results = m.allocInfer {
+ type = 'table',
+ source = source,
+ }
+ return true
+ elseif source.type == 'function' then
+ status.results = m.allocInfer {
+ type = 'function',
+ source = source,
+ }
+ return true
+ elseif source.type == '...' then
+ status.results = m.allocInfer {
+ type = '...',
+ source = source,
+ }
+ return true
+ end
+end
+
+local function getDocAliasExtends(status, name)
+ if not status.interface.docType then
+ return nil
+ end
+ for _, doc in ipairs(status.interface.docType(name)) do
+ if doc.type == 'doc.alias.name' then
+ return m.viewInferType(m.getDocTypeNames(status, doc.parent.extends))
+ end
+ end
+ return nil
+end
+
+local function getDocTypeUnitName(status, unit, genericCallback)
+ local typeName
+ if unit.type == 'doc.type.name' then
+ typeName = getDocAliasExtends(status, unit[1]) or unit[1]
+ elseif unit.type == 'doc.type.function' then
+ typeName = 'function'
+ elseif unit.type == 'doc.type.array' then
+ typeName = getDocTypeUnitName(status, unit.node, genericCallback) .. '[]'
+ elseif unit.type == 'doc.type.generic' then
+ typeName = ('%s<%s, %s>'):format(
+ getDocTypeUnitName(status, unit.node, genericCallback),
+ m.viewInferType(m.getDocTypeNames(status, unit.key, genericCallback)),
+ m.viewInferType(m.getDocTypeNames(status, unit.value, genericCallback))
+ )
+ end
+ if unit.typeGeneric then
+ if genericCallback then
+ typeName = genericCallback(typeName, unit)
+ or ('<%s>'):format(typeName)
+ else
+ typeName = ('<%s>'):format(typeName)
+ end
+ end
+ return typeName
+end
+
+function m.getDocTypeNames(status, doc, genericCallback)
+ local results = {}
+ if not doc then
+ return results
+ end
+ for _, unit in ipairs(doc.types) do
+ local typeName = getDocTypeUnitName(status, unit, genericCallback)
+ results[#results+1] = {
+ type = typeName,
+ source = unit,
+ }
+ end
+ for _, enum in ipairs(doc.enums) do
+ results[#results+1] = {
+ type = enum[1],
+ source = enum,
+ }
+ end
+ for _, resume in ipairs(doc.resumes) do
+ if not resume.additional then
+ results[#results+1] = {
+ type = resume[1],
+ source = resume,
+ }
+ end
+ end
+ return results
+end
+
+function m.inferCheckDoc(status, source)
+ if source.type == 'doc.class.name' then
+ status.results[#status.results+1] = {
+ type = source[1],
+ source = source,
+ }
+ return true
+ end
+ if source.type == 'doc.class' then
+ status.results[#status.results+1] = {
+ type = source.class[1],
+ source = source,
+ }
+ return true
+ end
+ if source.type == 'doc.type' then
+ local results = m.getDocTypeNames(status, source)
+ for _, res in ipairs(results) do
+ status.results[#status.results+1] = res
+ end
+ return true
+ end
+ if source.type == 'doc.field' then
+ local results = m.getDocTypeNames(status, source.extends)
+ for _, res in ipairs(results) do
+ status.results[#status.results+1] = res
+ end
+ return true
+ end
+end
+
+function m.getVarargDocType(status, source)
+ local func = m.getParentFunction(source)
+ if not func then
+ return
+ end
+ if not func.args then
+ return
+ end
+ for _, arg in ipairs(func.args) do
+ if arg.type == '...' then
+ if arg.bindDocs then
+ for _, doc in ipairs(arg.bindDocs) do
+ if doc.type == 'doc.vararg' then
+ return m.getDocTypeNames(status, doc.vararg)
+ end
+ end
+ end
+ end
+ end
+end
+
+function m.inferCheckUpDocOfVararg(status, source)
+ if not source.vararg then
+ return
+ end
+ local results = m.getVarargDocType(status, source)
+ if not results then
+ return
+ end
+ for _, res in ipairs(results) do
+ status.results[#status.results+1] = res
+ end
+ return true
+end
+
+function m.inferCheckUpDoc(status, source)
+ if m.inferCheckUpDocOfVararg(status, source) then
+ return true
+ end
+ local parent = source.parent
+ if parent then
+ if parent.type == 'local'
+ or parent.type == 'setlocal'
+ or parent.type == 'setglobal' then
+ source = parent
+ end
+ if parent.type == 'setfield'
+ or parent.type == 'tablefield' then
+ if parent.field == source
+ or parent.value == source then
+ source = parent
+ end
+ end
+ if parent.type == 'setmethod' then
+ if parent.method == source
+ or parent.value == source then
+ source = parent
+ end
+ end
+ if parent.type == 'setindex'
+ or parent.type == 'tableindex' then
+ if parent.index == source
+ or parent.value == source then
+ source = parent
+ end
+ end
+ end
+ local binds = source.bindDocs
+ if not binds then
+ return
+ end
+ status.results = {}
+ for _, doc in ipairs(binds) do
+ if doc.type == 'doc.class' then
+ status.results[#status.results+1] = {
+ type = doc.class[1],
+ source = doc,
+ }
+ -- ---@class Class
+ -- local x = { field = 1 }
+ -- 这种情况下,将字面量表接受为Class的定义
+ if source.value and source.value.type == 'table' then
+ status.results[#status.results+1] = {
+ type = source.value.type,
+ source = source.value,
+ }
+ end
+ return true
+ elseif doc.type == 'doc.type' then
+ local results = m.getDocTypeNames(status, doc)
+ for _, res in ipairs(results) do
+ status.results[#status.results+1] = res
+ end
+ return true
+ elseif doc.type == 'doc.param' then
+ -- function (x) 的情况
+ if source.type == 'local'
+ and m.getName(source) == doc.param[1] then
+ if source.parent.type == 'funcargs'
+ or source.parent.type == 'in'
+ or source.parent.type == 'loop' then
+ local results = m.getDocTypeNames(status, doc.extends)
+ for _, res in ipairs(results) do
+ status.results[#status.results+1] = res
+ end
+ return true
+ end
+ end
+ end
+ end
+end
+
+function m.inferCheckFieldDoc(status, source)
+ -- 检查 string[] 的情况
+ if source.type == 'getindex' then
+ local node = source.node
+ if not node then
+ return
+ end
+ local newStatus = m.status(status)
+ m.searchInfer(newStatus, node)
+ local ok
+ for _, infer in ipairs(newStatus.results) do
+ local src = infer.source
+ if src.type == 'doc.type.array' then
+ ok = true
+ status.results[#status.results+1] = {
+ type = infer.type:gsub('%[%]$', ''),
+ source = src.node,
+ }
+ end
+ end
+ return ok
+ end
+end
+
+function m.inferCheckUnary(status, source)
+ if source.type ~= 'unary' then
+ return
+ end
+ local op = source.op
+ if op.type == 'not' then
+ local checkTrue = m.checkTrue(status, source[1])
+ local value = nil
+ if checkTrue == true then
+ value = false
+ elseif checkTrue == false then
+ value = true
+ end
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = value,
+ source = source,
+ }
+ return true
+ elseif op.type == '#' then
+ status.results = m.allocInfer {
+ type = 'integer',
+ source = source,
+ }
+ return true
+ elseif op.type == '~' then
+ local l = m.getInferLiteral(status, source[1], 'integer')
+ status.results = m.allocInfer {
+ type = 'integer',
+ value = l and ~l or nil,
+ source = source,
+ }
+ return true
+ elseif op.type == '-' then
+ local v = m.getInferLiteral(status, source[1], 'integer')
+ if v then
+ status.results = m.allocInfer {
+ type = 'integer',
+ value = - v,
+ source = source,
+ }
+ return true
+ end
+ v = m.getInferLiteral(status, source[1], 'number')
+ status.results = m.allocInfer {
+ type = 'number',
+ value = v and -v or nil,
+ source = source,
+ }
+ return true
+ end
+end
+
+local function mathCheck(status, a, b)
+ local v1 = m.getInferLiteral(status, a, 'integer')
+ or m.getInferLiteral(status, a, 'number')
+ local v2 = m.getInferLiteral(status, b, 'integer')
+ or m.getInferLiteral(status, a, 'number')
+ local int = m.hasType(status, a, 'integer')
+ and m.hasType(status, b, 'integer')
+ and not m.hasType(status, a, 'number')
+ and not m.hasType(status, b, 'number')
+ return int and 'integer' or 'number', v1, v2
+end
+
+function m.inferCheckBinary(status, source)
+ if source.type ~= 'binary' then
+ return
+ end
+ local op = source.op
+ if op.type == 'and' then
+ local isTrue = m.checkTrue(status, source[1])
+ if isTrue == true then
+ m.searchInfer(status, source[2])
+ return true
+ elseif isTrue == false then
+ m.searchInfer(status, source[1])
+ return true
+ else
+ m.searchInfer(status, source[1])
+ m.searchInfer(status, source[2])
+ return true
+ end
+ elseif op.type == 'or' then
+ local isTrue = m.checkTrue(status, source[1])
+ if isTrue == true then
+ m.searchInfer(status, source[1])
+ return true
+ elseif isTrue == false then
+ m.searchInfer(status, source[2])
+ return true
+ else
+ m.searchInfer(status, source[1])
+ m.searchInfer(status, source[2])
+ return true
+ end
+ elseif op.type == '==' then
+ local value = m.isSameValue(status, source[1], source[2])
+ if value ~= nil then
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = value,
+ source = source,
+ }
+ return true
+ end
+ --local isSame = m.isSameDef(status, source[1], source[2])
+ --if isSame == true then
+ -- value = true
+ --else
+ -- value = nil
+ --end
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = value,
+ source = source,
+ }
+ return true
+ elseif op.type == '~=' then
+ local value = m.isSameValue(status, source[1], source[2])
+ if value ~= nil then
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = not value,
+ source = source,
+ }
+ return true
+ end
+ --local isSame = m.isSameDef(status, source[1], source[2])
+ --if isSame == true then
+ -- value = false
+ --else
+ -- value = nil
+ --end
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = value,
+ source = source,
+ }
+ return true
+ elseif op.type == '<=' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ or m.getInferLiteral(status, source[1], 'number')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ or m.getInferLiteral(status, source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 <= v2
+ end
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '>=' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ or m.getInferLiteral(status, source[1], 'number')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ or m.getInferLiteral(status, source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 >= v2
+ end
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '<' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ or m.getInferLiteral(status, source[1], 'number')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ or m.getInferLiteral(status, source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 < v2
+ end
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '>' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ or m.getInferLiteral(status, source[1], 'number')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ or m.getInferLiteral(status, source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 > v2
+ end
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '|' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ local v
+ if v1 and v2 then
+ v = v1 | v2
+ end
+ status.results = m.allocInfer {
+ type = 'integer',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '~' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ local v
+ if v1 and v2 then
+ v = v1 ~ v2
+ end
+ status.results = m.allocInfer {
+ type = 'integer',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '&' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ local v
+ if v1 and v2 then
+ v = v1 & v2
+ end
+ status.results = m.allocInfer {
+ type = 'integer',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '<<' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ local v
+ if v1 and v2 then
+ v = v1 << v2
+ end
+ status.results = m.allocInfer {
+ type = 'integer',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '>>' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ local v
+ if v1 and v2 then
+ v = v1 >> v2
+ end
+ status.results = m.allocInfer {
+ type = 'integer',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '..' then
+ local v1 = m.getInferLiteral(status, source[1], 'string')
+ local v2 = m.getInferLiteral(status, source[2], 'string')
+ local v
+ if v1 and v2 then
+ v = v1 .. v2
+ end
+ status.results = m.allocInfer {
+ type = 'string',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '^' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ or m.getInferLiteral(status, source[1], 'number')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ or m.getInferLiteral(status, source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 ^ v2
+ end
+ status.results = m.allocInfer {
+ type = 'number',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '/' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ or m.getInferLiteral(status, source[1], 'number')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ or m.getInferLiteral(status, source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 > v2
+ end
+ status.results = m.allocInfer {
+ type = 'number',
+ value = v,
+ source = source,
+ }
+ return true
+ -- 其他数学运算根据2侧的值决定,当2侧的值均为整数时返回整数
+ elseif op.type == '+' then
+ local int, v1, v2 = mathCheck(status, source[1], source[2])
+ status.results = m.allocInfer{
+ type = int,
+ value = (v1 and v2) and (v1 + v2) or nil,
+ source = source,
+ }
+ return true
+ elseif op.type == '-' then
+ local int, v1, v2 = mathCheck(status, source[1], source[2])
+ status.results = m.allocInfer{
+ type = int,
+ value = (v1 and v2) and (v1 - v2) or nil,
+ source = source,
+ }
+ return true
+ elseif op.type == '*' then
+ local int, v1, v2 = mathCheck(status, source[1], source[2])
+ status.results = m.allocInfer {
+ type = int,
+ value = (v1 and v2) and (v1 * v2) or nil,
+ source = source,
+ }
+ return true
+ elseif op.type == '%' then
+ local int, v1, v2 = mathCheck(status, source[1], source[2])
+ status.results = m.allocInfer {
+ type = int,
+ value = (v1 and v2) and (v1 % v2) or nil,
+ source = source,
+ }
+ return true
+ elseif op.type == '//' then
+ local int, v1, v2 = mathCheck(status, source[1], source[2])
+ status.results = m.allocInfer {
+ type = int,
+ value = (v1 and v2) and (v1 // v2) or nil,
+ source = source,
+ }
+ return true
+ end
+end
+
+function m.inferByDef(status, obj)
+ if not status.cache.inferedDef then
+ status.cache.inferedDef = {}
+ end
+ if status.cache.inferedDef[obj] then
+ return
+ end
+ status.cache.inferedDef[obj] = true
+ local mark = {}
+ local newStatus = m.status(status, status.interface)
+ m.searchRefs(newStatus, obj, 'def')
+ for _, src in ipairs(newStatus.results) do
+ local inferStatus = m.status(newStatus)
+ m.searchInfer(inferStatus, src)
+ for _, infer in ipairs(inferStatus.results) do
+ if not mark[infer.source] then
+ mark[infer.source] = true
+ status.results[#status.results+1] = infer
+ end
+ end
+ end
+end
+
+local function inferBySetOfLocal(status, source)
+ if status.cache[source] then
+ return
+ end
+ status.cache[source] = true
+ local newStatus = m.status(status)
+ if source.value then
+ m.searchInfer(newStatus, source.value)
+ end
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ if ref.type == 'setlocal' then
+ break
+ end
+ m.searchInfer(newStatus, ref)
+ end
+ for _, infer in ipairs(newStatus.results) do
+ status.results[#status.results+1] = infer
+ end
+ end
+end
+
+function m.inferBySet(status, source)
+ if #status.results ~= 0 then
+ return
+ end
+ if source.type == 'local' then
+ inferBySetOfLocal(status, source)
+ elseif source.type == 'setlocal'
+ or source.type == 'getlocal' then
+ inferBySetOfLocal(status, source.node)
+ end
+end
+
+function m.inferByCall(status, source)
+ if #status.results ~= 0 then
+ return
+ end
+ if not source.parent then
+ return
+ end
+ if source.parent.type ~= 'call' then
+ return
+ end
+ if source.parent.node == source then
+ status.results[#status.results+1] = {
+ type = 'function',
+ source = source,
+ }
+ return
+ end
+end
+
+function m.inferByGetTable(status, source)
+ if #status.results ~= 0 then
+ return
+ end
+ if source.type == 'field'
+ or source.type == 'method' then
+ source = source.parent
+ end
+ local next = source.next
+ if not next then
+ return
+ end
+ if next.type == 'getfield'
+ or next.type == 'getindex'
+ or next.type == 'getmethod'
+ or next.type == 'setfield'
+ or next.type == 'setindex'
+ or next.type == 'setmethod' then
+ status.results[#status.results+1] = {
+ type = 'table',
+ source = source,
+ }
+ end
+end
+
+function m.inferByUnary(status, source)
+ if #status.results ~= 0 then
+ return
+ end
+ local parent = source.parent
+ if not parent or parent.type ~= 'unary' then
+ return
+ end
+ local op = parent.op
+ if op.type == '#' then
+ status.results[#status.results+1] = {
+ type = 'string',
+ source = source
+ }
+ status.results[#status.results+1] = {
+ type = 'table',
+ source = source
+ }
+ elseif op.type == '~' then
+ status.results[#status.results+1] = {
+ type = 'integer',
+ source = source
+ }
+ elseif op.type == '-' then
+ status.results[#status.results+1] = {
+ type = 'number',
+ source = source
+ }
+ end
+end
+
+function m.inferByBinary(status, source)
+ if #status.results ~= 0 then
+ return
+ end
+ local parent = source.parent
+ if not parent or parent.type ~= 'binary' then
+ return
+ end
+ local op = parent.op
+ if op.type == '<='
+ or op.type == '>='
+ or op.type == '<'
+ or op.type == '>'
+ or op.type == '^'
+ or op.type == '/'
+ or op.type == '+'
+ or op.type == '-'
+ or op.type == '*'
+ or op.type == '%' then
+ status.results[#status.results+1] = {
+ type = 'number',
+ source = source,
+ }
+ elseif op.type == '|'
+ or op.type == '~'
+ or op.type == '&'
+ or op.type == '<<'
+ or op.type == '>>'
+ -- 整数的可能性比较高
+ or op.type == '//' then
+ status.results[#status.results+1] = {
+ type = 'integer',
+ source = source,
+ }
+ elseif op.type == '..' then
+ status.results[#status.results+1] = {
+ type = 'string',
+ source = source,
+ }
+ end
+end
+
+local function mergeFunctionReturnsByDoc(status, source, index, call)
+ if not source or source.type ~= 'function' then
+ return
+ end
+ if not source.bindDocs then
+ return
+ end
+ local returns = {}
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.return' then
+ for _, rtn in ipairs(doc.returns) do
+ returns[#returns+1] = rtn
+ end
+ end
+ end
+ local rtn = returns[index]
+ if not rtn then
+ return
+ end
+ local results = m.getDocTypeNames(status, rtn, function (typeName, typeUnit)
+ if not source.args or not call.args then
+ return
+ end
+ local name = typeUnit[1]
+ local generics = typeUnit.typeGeneric[name]
+ if not generics then
+ return
+ end
+ local first = generics[1]
+ if not first or first == typeUnit then
+ return
+ end
+ local docParam = m.getParentType(first, 'doc.param')
+ local paramName = docParam.param[1]
+ for i, arg in ipairs(source.args) do
+ if arg[1] == paramName then
+ local callArg = call.args[i]
+ if not callArg then
+ return
+ end
+ return m.viewInferType(m.searchInfer(status, callArg))
+ end
+ end
+ end)
+ if #results == 0 then
+ return
+ end
+ for _, res in ipairs(results) do
+ status.results[#status.results+1] = res
+ end
+ return true
+end
+
+local function mergeDocTypeFunctionReturns(status, source, index)
+ if not source.bindDocs then
+ return
+ end
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.type' then
+ for _, typeUnit in ipairs(doc.types) do
+ if typeUnit.type == 'doc.type.function' then
+ local rtn = typeUnit.returns[index]
+ if rtn then
+ local results = m.getDocTypeNames(status, rtn)
+ for _, res in ipairs(results) do
+ status.results[#status.results+1] = res
+ end
+ end
+ end
+ end
+ end
+ end
+end
+
+local function mergeFunctionReturns(status, source, index, call)
+ local returns = source.returns
+ if not returns then
+ return
+ end
+ for i = 1, #returns do
+ local rtn = returns[i]
+ if rtn[index] then
+ if rtn[index].type == 'call' then
+ if not m.checkReturnMark(status, rtn[index]) then
+ m.checkReturnMark(status, rtn[index], true)
+ m.inferByCallReturnAndIndex(status, rtn[index], index)
+ end
+ else
+ local newStatus = m.status(status)
+ m.searchInfer(newStatus, rtn[index])
+ if #newStatus.results == 0 then
+ status.results[#status.results+1] = {
+ type = 'any',
+ source = rtn[index],
+ }
+ else
+ for _, infer in ipairs(newStatus.results) do
+ status.results[#status.results+1] = infer
+ end
+ end
+ end
+ end
+ end
+end
+
+function m.inferByCallReturnAndIndex(status, call, index)
+ local node = call.node
+ local newStatus = m.status(nil, status.interface)
+ m.searchRefs(newStatus, node, 'def')
+ local hasDocReturn
+ for _, src in ipairs(newStatus.results) do
+ if mergeDocTypeFunctionReturns(status, src, index) then
+ hasDocReturn = true
+ elseif mergeFunctionReturnsByDoc(status, src.value, index, call) then
+ hasDocReturn = true
+ end
+ end
+ if not hasDocReturn then
+ for _, src in ipairs(newStatus.results) do
+ if src.value and src.value.type == 'function' then
+ if not m.checkReturnMark(status, src.value, true) then
+ mergeFunctionReturns(status, src.value, index, call)
+ end
+ end
+ end
+ end
+end
+
+function m.inferByCallReturn(status, source)
+ if source.type == 'call' then
+ m.inferByCallReturnAndIndex(status, source, 1)
+ return
+ end
+ if source.type ~= 'select' then
+ if source.value and source.value.type == 'select' then
+ source = source.value
+ else
+ return
+ end
+ end
+ if not source.vararg or source.vararg.type ~= 'call' then
+ return
+ end
+ m.inferByCallReturnAndIndex(status, source.vararg, source.index)
+end
+
+function m.inferByPCallReturn(status, source)
+ if source.type ~= 'select' then
+ if source.value and source.value.type == 'select' then
+ source = source.value
+ else
+ return
+ end
+ end
+ local call = source.vararg
+ if not call or call.type ~= 'call' then
+ return
+ end
+ local node = call.node
+ local specialName = node.special
+ local func, index
+ if specialName == 'pcall' then
+ func = call.args[1]
+ index = source.index - 1
+ elseif specialName == 'xpcall' then
+ func = call.args[1]
+ index = source.index - 2
+ else
+ return
+ end
+ local newStatus = m.status(nil, status.interface)
+ m.searchRefs(newStatus, func, 'def')
+ for _, src in ipairs(newStatus.results) do
+ if src.value and src.value.type == 'function' then
+ mergeFunctionReturns(status, src.value, index)
+ end
+ end
+end
+
+function m.cleanInfers(infers)
+ local mark = {}
+ for i = #infers, 1, -1 do
+ local infer = infers[i]
+ local key = ('%s|%p'):format(infer.type, infer.source)
+ if mark[key] then
+ infers[i] = infers[#infers]
+ infers[#infers] = nil
+ else
+ mark[key] = true
+ end
+ end
+end
+
+function m.searchInfer(status, obj)
+ while obj.type == 'paren' do
+ obj = obj.exp
+ if not obj then
+ return
+ end
+ end
+ while true do
+ local value = m.getObjectValue(obj)
+ if not value or value == obj then
+ break
+ end
+ obj = value
+ end
+
+ local cache, makeCache = m.getRefCache(status, obj, 'infer')
+ if cache then
+ for i = 1, #cache do
+ status.results[#status.results+1] = cache[i]
+ end
+ return
+ end
+
+ if DEVELOP then
+ status.cache.clock = status.cache.clock or osClock()
+ end
+
+ if not status.cache.lockInfer then
+ status.cache.lockInfer = {}
+ end
+ if status.cache.lockInfer[obj] then
+ return
+ end
+ status.cache.lockInfer[obj] = true
+
+ local checked = m.inferCheckDoc(status, obj)
+ or m.inferCheckUpDoc(status, obj)
+ or m.inferCheckFieldDoc(status, obj)
+ or m.inferCheckLiteral(status, obj)
+ or m.inferCheckUnary(status, obj)
+ or m.inferCheckBinary(status, obj)
+ if checked then
+ m.cleanInfers(status.results)
+ if makeCache then
+ makeCache(status.results)
+ end
+ return
+ end
+
+ if status.deep then
+ m.inferByDef(status, obj)
+ end
+ m.inferBySet(status, obj)
+ m.inferByCall(status, obj)
+ m.inferByGetTable(status, obj)
+ m.inferByUnary(status, obj)
+ m.inferByBinary(status, obj)
+ m.inferByCallReturn(status, obj)
+ m.inferByPCallReturn(status, obj)
+ m.cleanInfers(status.results)
+ if makeCache then
+ makeCache(status.results)
+ end
+end
+
+--- 请求对象的引用,包括 `a.b.c` 形式
+--- 与 `return function` 形式。
+--- 不穿透 `setmetatable` ,考虑由
+--- 业务层进行反向 def 搜索。
+function m.requestReference(obj, interface, deep)
+ local status = m.status(nil, interface)
+ status.deep = deep
+ -- 根据 field 搜索引用
+ m.searchRefs(status, obj, 'ref')
+
+ m.searchRefsAsFunction(status, obj, 'ref')
+
+ if m.debugMode then
+ print('count:', status.cache.count)
+ end
+
+ return status.results, status.cache.count
+end
+
+--- 请求对象的定义,包括 `a.b.c` 形式
+--- 与 `return function` 形式。
+--- 穿透 `setmetatable` 。
+function m.requestDefinition(obj, interface, deep)
+ local status = m.status(nil, interface)
+ status.deep = deep
+ -- 根据 field 搜索定义
+ m.searchRefs(status, obj, 'def')
+
+ return status.results, status.cache.count
+end
+
+--- 请求对象的域
+function m.requestFields(obj, interface, deep)
+ local status = m.status(nil, interface)
+ status.deep = deep
+
+ m.searchFields(status, obj)
+
+ return status.results, status.cache.count
+end
+
+--- 请求对象的类型推测
+function m.requestInfer(obj, interface, deep)
+ local status = m.status(nil, interface)
+ status.deep = deep
+ m.searchInfer(status, obj)
+
+ return status.results, status.cache.count
+end
+
+return m
diff --git a/script/parser/init.lua b/script/parser/init.lua
new file mode 100644
index 00000000..ba40d145
--- /dev/null
+++ b/script/parser/init.lua
@@ -0,0 +1,12 @@
+local api = {
+ grammar = require 'parser.grammar',
+ parse = require 'parser.parse',
+ compile = require 'parser.compile',
+ split = require 'parser.split',
+ calcline = require 'parser.calcline',
+ lines = require 'parser.lines',
+ guide = require 'parser.guide',
+ luadoc = require 'parser.luadoc',
+}
+
+return api
diff --git a/script/parser/lines.lua b/script/parser/lines.lua
new file mode 100644
index 00000000..ee6b4f41
--- /dev/null
+++ b/script/parser/lines.lua
@@ -0,0 +1,45 @@
+local m = require 'lpeglabel'
+
+_ENV = nil
+
+local function Line(start, line, range, finish)
+ line.start = start
+ line.finish = finish - 1
+ line.range = range - 1
+ return line
+end
+
+local function Space(...)
+ local line = {...}
+ local sp = 0
+ local tab = 0
+ for i = 1, #line do
+ if line[i] == ' ' then
+ sp = sp + 1
+ elseif line[i] == '\t' then
+ tab = tab + 1
+ end
+ line[i] = nil
+ end
+ line.sp = sp
+ line.tab = tab
+ return line
+end
+
+local parser = m.P{
+'Lines',
+Lines = m.Ct(m.V'Line'^0 * m.V'LastLine'),
+Line = m.Cp() * m.V'Indent' * (1 - m.V'Nl')^0 * m.Cp() * m.V'Nl' * m.Cp() / Line,
+LastLine= m.Cp() * m.V'Indent' * (1 - m.V'Nl')^0 * m.Cp() * m.Cp() / Line,
+Nl = m.P'\r\n' + m.S'\r\n',
+Indent = m.C(m.S' \t')^0 / Space,
+}
+
+return function (self, text)
+ local lines, err = parser:match(text)
+ if not lines then
+ return nil, err
+ end
+
+ return lines
+end
diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua
new file mode 100644
index 00000000..b31c4baf
--- /dev/null
+++ b/script/parser/luadoc.lua
@@ -0,0 +1,991 @@
+local m = require 'lpeglabel'
+local re = require 'parser.relabel'
+local lines = require 'parser.lines'
+local guide = require 'parser.guide'
+
+local TokenTypes, TokenStarts, TokenFinishs, TokenContents
+local Ci, Offset, pushError, Ct, NextComment
+local parseType
+local Parser = re.compile([[
+Main <- (Token / Sp)*
+Sp <- %s+
+X16 <- [a-fA-F0-9]
+Word <- [a-zA-Z0-9_]
+Token <- Name / String / Symbol
+Name <- ({} {[a-zA-Z0-9_] [a-zA-Z0-9_.*]*} {})
+ -> Name
+String <- ({} StringDef {})
+ -> String
+StringDef <- '"'
+ {~(Esc / !'"' .)*~} -> 1
+ ('"'?)
+ / "'"
+ {~(Esc / !"'" .)*~} -> 1
+ ("'"?)
+ / ('[' {:eq: '='* :} '['
+ {(!StringClose .)*} -> 1
+ (StringClose?))
+StringClose <- ']' =eq ']'
+Esc <- '\' -> ''
+ EChar
+EChar <- 'a' -> ea
+ / 'b' -> eb
+ / 'f' -> ef
+ / 'n' -> en
+ / 'r' -> er
+ / 't' -> et
+ / 'v' -> ev
+ / '\'
+ / '"'
+ / "'"
+ / %nl
+ / ('z' (%nl / %s)*) -> ''
+ / ('x' {X16 X16}) -> Char16
+ / ([0-9] [0-9]? [0-9]?) -> Char10
+ / ('u{' {Word*} '}') -> CharUtf8
+Symbol <- ({} {
+ ':'
+ / '|'
+ / ','
+ / '[]'
+ / '<'
+ / '>'
+ / '('
+ / ')'
+ / '?'
+ / '...'
+ / '+'
+ } {})
+ -> Symbol
+]], {
+ s = m.S' \t',
+ ea = '\a',
+ eb = '\b',
+ ef = '\f',
+ en = '\n',
+ er = '\r',
+ et = '\t',
+ ev = '\v',
+ Char10 = function (char)
+ char = tonumber(char)
+ if not char or char < 0 or char > 255 then
+ return ''
+ end
+ return string.char(char)
+ end,
+ Char16 = function (char)
+ return string.char(tonumber(char, 16))
+ end,
+ CharUtf8 = function (char)
+ if #char == 0 then
+ return ''
+ end
+ local v = tonumber(char, 16)
+ if not v then
+ return ''
+ end
+ if v >= 0 and v <= 0x10FFFF then
+ return utf8.char(v)
+ end
+ return ''
+ end,
+ Name = function (start, content, finish)
+ Ci = Ci + 1
+ TokenTypes[Ci] = 'name'
+ TokenStarts[Ci] = start
+ TokenFinishs[Ci] = finish - 1
+ TokenContents[Ci] = content
+ end,
+ String = function (start, content, finish)
+ Ci = Ci + 1
+ TokenTypes[Ci] = 'string'
+ TokenStarts[Ci] = start
+ TokenFinishs[Ci] = finish - 1
+ TokenContents[Ci] = content
+ end,
+ Symbol = function (start, content, finish)
+ Ci = Ci + 1
+ TokenTypes[Ci] = 'symbol'
+ TokenStarts[Ci] = start
+ TokenFinishs[Ci] = finish - 1
+ TokenContents[Ci] = content
+ end,
+})
+
+local function trim(str)
+ return str:match '^%s*(%S+)%s*$'
+end
+
+local function parseTokens(text, offset)
+ Ct = offset
+ Ci = 0
+ Offset = offset
+ TokenTypes = {}
+ TokenStarts = {}
+ TokenFinishs = {}
+ TokenContents = {}
+ Parser:match(text)
+ Ci = 0
+end
+
+local function peekToken()
+ return TokenTypes[Ci+1], TokenContents[Ci+1]
+end
+
+local function nextToken()
+ Ci = Ci + 1
+ if not TokenTypes[Ci] then
+ Ci = Ci - 1
+ return nil
+ end
+ return TokenTypes[Ci], TokenContents[Ci]
+end
+
+local function checkToken(tp, content, offset)
+ offset = offset or 0
+ return TokenTypes[Ci + offset] == tp
+ and TokenContents[Ci + offset] == content
+end
+
+local function getStart()
+ if Ci == 0 then
+ return Offset
+ end
+ return TokenStarts[Ci] + Offset
+end
+
+local function getFinish()
+ if Ci == 0 then
+ return Offset
+ end
+ return TokenFinishs[Ci] + Offset
+end
+
+local function try(callback)
+ local savePoint = Ci
+ -- rollback
+ local suc = callback()
+ if not suc then
+ Ci = savePoint
+ end
+ return suc
+end
+
+local function parseName(tp, parent)
+ local nameTp, nameText = peekToken()
+ if nameTp ~= 'name' then
+ return nil
+ end
+ nextToken()
+ local class = {
+ type = tp,
+ start = getStart(),
+ finish = getFinish(),
+ parent = parent,
+ [1] = nameText,
+ }
+ return class
+end
+
+local function parseClass(parent)
+ local result = {
+ type = 'doc.class',
+ parent = parent,
+ }
+ result.class = parseName('doc.class.name', result)
+ if not result.class then
+ pushError {
+ type = 'LUADOC_MISS_CLASS_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ result.start = getStart()
+ result.finish = getFinish()
+ if not peekToken() then
+ return result
+ end
+ nextToken()
+ if not checkToken('symbol', ':') then
+ pushError {
+ type = 'LUADOC_MISS_EXTENDS_SYMBOL',
+ start = result.finish + 1,
+ finish = getStart() - 1,
+ }
+ return result
+ end
+ result.extends = parseName('doc.extends.name', result)
+ if not result.extends then
+ pushError {
+ type = 'LUADOC_MISS_CLASS_EXTENDS_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return result
+ end
+ result.finish = getFinish()
+ return result
+end
+
+local function nextSymbolOrError(symbol)
+ if checkToken('symbol', symbol, 1) then
+ nextToken()
+ return true
+ end
+ pushError {
+ type = 'LUADOC_MISS_SYMBOL',
+ start = getFinish(),
+ finish = getFinish(),
+ info = {
+ symbol = symbol,
+ }
+ }
+ return false
+end
+
+local function parseTypeUnitArray(node)
+ if not checkToken('symbol', '[]', 1) then
+ return nil
+ end
+ nextToken()
+ local result = {
+ type = 'doc.type.array',
+ start = node.start,
+ finish = getFinish(),
+ node = node,
+ }
+ return result
+end
+
+local function parseTypeUnitGeneric(node)
+ if not checkToken('symbol', '<', 1) then
+ return nil
+ end
+ if not nextSymbolOrError('<') then
+ return nil
+ end
+ local key = parseType(node)
+ if not key or not nextSymbolOrError(',') then
+ return nil
+ end
+ local value = parseType(node)
+ if not value then
+ return nil
+ end
+ nextSymbolOrError('>')
+ local result = {
+ type = 'doc.type.generic',
+ start = node.start,
+ finish = getFinish(),
+ node = node,
+ key = key,
+ value = value,
+ }
+ return result
+end
+
+local function parseTypeUnitFunction()
+ local typeUnit = {
+ type = 'doc.type.function',
+ start = getStart(),
+ args = {},
+ returns = {},
+ }
+ if not nextSymbolOrError('(') then
+ return nil
+ end
+ while true do
+ if checkToken('symbol', ')', 1) then
+ nextToken()
+ break
+ end
+ local arg = {
+ type = 'doc.type.arg',
+ parent = typeUnit,
+ }
+ arg.name = parseName('doc.type.name', arg)
+ if not arg.name then
+ pushError {
+ type = 'LUADOC_MISS_ARG_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ break
+ end
+ if not arg.start then
+ arg.start = arg.name.start
+ end
+ if checkToken('symbol', '?', 1) then
+ nextToken()
+ arg.optional = true
+ end
+ arg.finish = getFinish()
+ if not nextSymbolOrError(':') then
+ break
+ end
+ arg.extends = parseType(arg)
+ if not arg.extends then
+ break
+ end
+ arg.finish = getFinish()
+ typeUnit.args[#typeUnit.args+1] = arg
+ if checkToken('symbol', ',', 1) then
+ nextToken()
+ else
+ nextSymbolOrError(')')
+ break
+ end
+ end
+ if checkToken('symbol', ':', 1) then
+ nextToken()
+ while true do
+ local rtn = parseType(typeUnit)
+ if not rtn then
+ break
+ end
+ if checkToken('symbol', '?', 1) then
+ nextToken()
+ rtn.optional = true
+ end
+ typeUnit.returns[#typeUnit.returns+1] = rtn
+ if checkToken('symbol', ',', 1) then
+ nextToken()
+ else
+ break
+ end
+ end
+ end
+ typeUnit.finish = getFinish()
+ return typeUnit
+end
+
+local function parseTypeUnit(parent, content)
+ local result
+ if content == 'fun' then
+ result = parseTypeUnitFunction()
+ end
+ if not result then
+ result = {
+ type = 'doc.type.name',
+ start = getStart(),
+ finish = getFinish(),
+ [1] = content,
+ }
+ end
+ if not result then
+ return nil
+ end
+ result.parent = parent
+ while true do
+ local newResult = parseTypeUnitArray(result)
+ or parseTypeUnitGeneric(result)
+ if not newResult then
+ break
+ end
+ result = newResult
+ end
+ return result
+end
+
+local function parseResume()
+ local result = {
+ type = 'doc.resume'
+ }
+
+ if checkToken('symbol', '>', 1) then
+ nextToken()
+ result.default = true
+ end
+
+ if checkToken('symbol', '+', 1) then
+ nextToken()
+ result.additional = true
+ end
+
+ local tp = peekToken()
+ if tp ~= 'string' then
+ pushError {
+ type = 'LUADOC_MISS_STRING',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ local _, str = nextToken()
+ result[1] = str
+ result.start = getStart()
+ result.finish = getFinish()
+ return result
+end
+
+function parseType(parent)
+ local result = {
+ type = 'doc.type',
+ parent = parent,
+ types = {},
+ enums = {},
+ resumes = {},
+ }
+ result.start = getStart()
+ while true do
+ local tp, content = peekToken()
+ if not tp then
+ break
+ end
+ if tp == 'name' then
+ nextToken()
+ local typeUnit = parseTypeUnit(result, content)
+ if not typeUnit then
+ break
+ end
+ result.types[#result.types+1] = typeUnit
+ if not result.start then
+ result.start = typeUnit.start
+ end
+ elseif tp == 'string' then
+ nextToken()
+ local typeEnum = {
+ type = 'doc.type.enum',
+ start = getStart(),
+ finish = getFinish(),
+ parent = result,
+ [1] = content,
+ }
+ result.enums[#result.enums+1] = typeEnum
+ if not result.start then
+ result.start = typeEnum.start
+ end
+ elseif tp == 'symbol' and content == '...' then
+ nextToken()
+ local vararg = {
+ type = 'doc.type.name',
+ start = getStart(),
+ finish = getFinish(),
+ parent = result,
+ [1] = content,
+ }
+ result.types[#result.types+1] = vararg
+ if not result.start then
+ result.start = vararg.start
+ end
+ end
+ if not checkToken('symbol', '|', 1) then
+ break
+ end
+ nextToken()
+ end
+ result.finish = getFinish()
+
+ while true do
+ local nextComm = NextComment('peek')
+ if nextComm and nextComm.text:sub(1, 2) == '-|' then
+ NextComment()
+ local finishPos = nextComm.text:find('#', 3) or #nextComm.text
+ parseTokens(nextComm.text:sub(3, finishPos), nextComm.start + 1)
+ local resume = parseResume()
+ if resume then
+ resume.comment = nextComm.text:match('#%s*(.+)', 3)
+ result.resumes[#result.resumes+1] = resume
+ result.finish = resume.finish
+ end
+ else
+ break
+ end
+ end
+
+ if #result.types == 0 and #result.enums == 0 and #result.resumes == 0 then
+ pushError {
+ type = 'LUADOC_MISS_TYPE_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ return result
+end
+
+local function parseAlias()
+ local result = {
+ type = 'doc.alias',
+ }
+ result.alias = parseName('doc.alias.name', result)
+ if not result.alias then
+ pushError {
+ type = 'LUADOC_MISS_ALIAS_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ result.start = getStart()
+ result.extends = parseType(result)
+ if not result.extends then
+ pushError {
+ type = 'LUADOC_MISS_ALIAS_EXTENDS',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ result.finish = getFinish()
+ return result
+end
+
+local function parseParam()
+ local result = {
+ type = 'doc.param',
+ }
+ result.param = parseName('doc.param.name', result)
+ if not result.param then
+ pushError {
+ type = 'LUADOC_MISS_PARAM_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ if checkToken('symbol', '?', 1) then
+ nextToken()
+ result.optional = true
+ end
+ result.start = result.param.start
+ result.finish = getFinish()
+ result.extends = parseType(result)
+ if not result.extends then
+ pushError {
+ type = 'LUADOC_MISS_PARAM_EXTENDS',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return result
+ end
+ result.finish = getFinish()
+ return result
+end
+
+local function parseReturn()
+ local result = {
+ type = 'doc.return',
+ returns = {},
+ }
+ while true do
+ local docType = parseType(result)
+ if not docType then
+ break
+ end
+ if not result.start then
+ result.start = docType.start
+ end
+ if checkToken('symbol', '?', 1) then
+ nextToken()
+ docType.optional = true
+ end
+ docType.name = parseName('doc.return.name', docType)
+ result.returns[#result.returns+1] = docType
+ if not checkToken('symbol', ',', 1) then
+ break
+ end
+ nextToken()
+ end
+ if #result.returns == 0 then
+ return nil
+ end
+ result.finish = getFinish()
+ return result
+end
+
+local function parseField()
+ local result = {
+ type = 'doc.field',
+ }
+ try(function ()
+ local tp, value = nextToken()
+ if tp == 'name' then
+ if value == 'public'
+ or value == 'protected'
+ or value == 'private' then
+ result.visible = value
+ result.start = getStart()
+ return true
+ end
+ end
+ return false
+ end)
+ result.field = parseName('doc.field.name', result)
+ if not result.field then
+ pushError {
+ type = 'LUADOC_MISS_FIELD_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ if not result.start then
+ result.start = result.field.start
+ end
+ result.extends = parseType(result)
+ if not result.extends then
+ pushError {
+ type = 'LUADOC_MISS_FIELD_EXTENDS',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ result.finish = getFinish()
+ return result
+end
+
+local function parseGeneric()
+ local result = {
+ type = 'doc.generic',
+ generics = {},
+ }
+ while true do
+ local object = {
+ type = 'doc.generic.object',
+ parent = result,
+ }
+ object.generic = parseName('doc.generic.name', object)
+ if not object.generic then
+ pushError {
+ type = 'LUADOC_MISS_GENERIC_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ object.start = object.generic.start
+ if not result.start then
+ result.start = object.start
+ end
+ if checkToken('symbol', ':', 1) then
+ nextToken()
+ object.extends = parseType(object)
+ end
+ object.finish = getFinish()
+ result.generics[#result.generics+1] = object
+ if not checkToken('symbol', ',', 1) then
+ break
+ end
+ nextToken()
+ end
+ result.finish = getFinish()
+ return result
+end
+
+local function parseVararg()
+ local result = {
+ type = 'doc.vararg',
+ }
+ result.vararg = parseType(result)
+ if not result.vararg then
+ pushError {
+ type = 'LUADOC_MISS_VARARG_TYPE',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return
+ end
+ result.start = result.vararg.start
+ result.finish = result.vararg.finish
+ return result
+end
+
+local function parseOverload()
+ local tp, name = peekToken()
+ if tp ~= 'name' or name ~= 'fun' then
+ pushError {
+ type = 'LUADOC_MISS_FUN_AFTER_OVERLOAD',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ nextToken()
+ local result = {
+ type = 'doc.overload',
+ }
+ result.overload = parseTypeUnitFunction()
+ if not result.overload then
+ return nil
+ end
+ result.overload.parent = result
+ result.start = result.overload.start
+ result.finish = result.overload.finish
+ return result
+end
+
+local function parseDeprecated()
+ return {
+ type = 'doc.deprecated',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+end
+
+local function parseMeta()
+ return {
+ type = 'doc.meta',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+end
+
+local function parseVersion()
+ local result = {
+ type = 'doc.version',
+ versions = {},
+ }
+ while true do
+ local tp, text = nextToken()
+ if not tp then
+ pushError {
+ type = 'LUADOC_MISS_VERSION',
+ start = getStart(),
+ finish = getFinish(),
+ }
+ break
+ end
+ if not result.start then
+ result.start = getStart()
+ end
+ local version = {
+ type = 'doc.version.unit',
+ start = getStart(),
+ }
+ if tp == 'symbol' then
+ if text == '>' then
+ version.ge = true
+ elseif text == '<' then
+ version.le = true
+ end
+ tp, text = nextToken()
+ end
+ if tp ~= 'name' then
+ pushError {
+ type = 'LUADOC_MISS_VERSION',
+ start = getStart(),
+ finish = getFinish(),
+ }
+ break
+ end
+ version.version = tonumber(text) or text
+ version.finish = getFinish()
+ result.versions[#result.versions+1] = version
+ if not checkToken('symbol', ',', 1) then
+ break
+ end
+ nextToken()
+ end
+ if #result.versions == 0 then
+ return nil
+ end
+ result.finish = getFinish()
+ return result
+end
+
+local function convertTokens()
+ local tp, text = nextToken()
+ if not tp then
+ return
+ end
+ if tp ~= 'name' then
+ pushError {
+ type = 'LUADOC_MISS_CATE_NAME',
+ start = getStart(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ if text == 'class' then
+ return parseClass()
+ elseif text == 'type' then
+ return parseType()
+ elseif text == 'alias' then
+ return parseAlias()
+ elseif text == 'param' then
+ return parseParam()
+ elseif text == 'return' then
+ return parseReturn()
+ elseif text == 'field' then
+ return parseField()
+ elseif text == 'generic' then
+ return parseGeneric()
+ elseif text == 'vararg' then
+ return parseVararg()
+ elseif text == 'overload' then
+ return parseOverload()
+ elseif text == 'deprecated' then
+ return parseDeprecated()
+ elseif text == 'meta' then
+ return parseMeta()
+ elseif text == 'version' then
+ return parseVersion()
+ end
+end
+
+local function buildLuaDoc(comment)
+ local text = comment.text
+ if text:sub(1, 1) ~= '-' then
+ return
+ end
+ if text:sub(2, 2) ~= '@' then
+ return {
+ type = 'doc.comment',
+ start = comment.start,
+ finish = comment.finish,
+ comment = comment,
+ }
+ end
+ local finishPos = text:find('@', 3)
+ local doc, lastComment
+ if finishPos then
+ doc = text:sub(3, finishPos - 1)
+ lastComment = text:sub(finishPos)
+ else
+ doc = text:sub(3)
+ end
+
+ parseTokens(doc, comment.start + 1)
+ local result = convertTokens()
+ if result then
+ result.comment = lastComment
+ end
+
+ return result
+end
+
+local function isNextLine(lns, binded, doc)
+ if not binded then
+ return false
+ end
+ local lastDoc = binded[#binded]
+ local lastRow = guide.positionOf(lns, lastDoc.finish)
+ local newRow = guide.positionOf(lns, doc.start)
+ return newRow - lastRow == 1
+end
+
+local function bindGeneric(binded)
+ local generics = {}
+ for _, doc in ipairs(binded) do
+ if doc.type == 'doc.generic' then
+ for _, obj in ipairs(doc.generics) do
+ local name = obj.generic[1]
+ generics[name] = {}
+ end
+ elseif doc.type == 'doc.param'
+ or doc.type == 'doc.return' then
+ guide.eachSourceType(doc, 'doc.type.name', function (src)
+ local name = src[1]
+ if generics[name] then
+ generics[name][#generics[name]+1] = src
+ src.typeGeneric = generics
+ end
+ end)
+ end
+ end
+end
+
+local function bindDoc(state, lns, binded)
+ if not binded then
+ return
+ end
+ local lastDoc = binded[#binded]
+ if not lastDoc then
+ return
+ end
+ local bindSources = {}
+ for _, doc in ipairs(binded) do
+ doc.bindGroup = binded
+ doc.bindSources = bindSources
+ end
+ bindGeneric(binded)
+ local row = guide.positionOf(lns, lastDoc.finish)
+ local start, finish = guide.lineRange(lns, row + 1)
+ if start >= finish then
+ -- 空行
+ return
+ end
+ guide.eachSourceBetween(state.ast, start, finish, function (src)
+ if src.start and src.start < start then
+ return
+ end
+ if src.type == 'local'
+ or src.type == 'setlocal'
+ or src.type == 'setglobal'
+ or src.type == 'setfield'
+ or src.type == 'setmethod'
+ or src.type == 'setindex'
+ or src.type == 'tablefield'
+ or src.type == 'tableindex'
+ or src.type == 'function'
+ or src.type == '...' then
+ src.bindDocs = binded
+ bindSources[#bindSources+1] = src
+ end
+ end)
+end
+
+local function bindDocs(state)
+ local lns = lines(nil, state.lua)
+ local binded
+ for _, doc in ipairs(state.ast.docs) do
+ if not isNextLine(lns, binded, doc) then
+ bindDoc(state, lns, binded)
+ binded = {}
+ state.ast.docs.groups[#state.ast.docs.groups+1] = binded
+ end
+ binded[#binded+1] = doc
+ end
+ bindDoc(state, lns, binded)
+end
+
+return function (_, state)
+ local ast = state.ast
+ local comments = state.comms
+ table.sort(comments, function (a, b)
+ return a.start < b.start
+ end)
+ ast.docs = {
+ type = 'doc',
+ parent = ast,
+ groups = {},
+ }
+
+ pushError = state.pushError
+
+ local ci = 1
+ NextComment = function (peek)
+ local comment = comments[ci]
+ if not peek then
+ ci = ci + 1
+ end
+ return comment
+ end
+
+ while true do
+ local comment = NextComment()
+ if not comment then
+ break
+ end
+ local doc = buildLuaDoc(comment)
+ if doc then
+ ast.docs[#ast.docs+1] = doc
+ doc.parent = ast.docs
+ if ast.start > doc.start then
+ ast.start = doc.start
+ end
+ if ast.finish < doc.finish then
+ ast.finish = doc.finish
+ end
+ end
+ end
+
+ if #ast.docs == 0 then
+ return
+ end
+
+ bindDocs(state)
+end
diff --git a/script/parser/parse.lua b/script/parser/parse.lua
new file mode 100644
index 00000000..f813cc59
--- /dev/null
+++ b/script/parser/parse.lua
@@ -0,0 +1,49 @@
+local ast = require 'parser.ast'
+
+return function (self, lua, mode, version)
+ local errs = {}
+ local diags = {}
+ local comms = {}
+ local state = {
+ version = version,
+ lua = lua,
+ root = {},
+ errs = errs,
+ diags = diags,
+ comms = comms,
+ pushError = function (err)
+ if err.finish < err.start then
+ err.finish = err.start
+ end
+ local last = errs[#errs]
+ if last then
+ if last.start <= err.start and last.finish >= err.finish then
+ return
+ end
+ end
+ err.level = err.level or 'error'
+ errs[#errs+1] = err
+ return err
+ end,
+ pushDiag = function (code, info)
+ if not diags[code] then
+ diags[code] = {}
+ end
+ diags[code][#diags[code]+1] = info
+ end,
+ pushComment = function (comment)
+ comms[#comms+1] = comment
+ end
+ }
+ ast.init(state)
+ local suc, res, err = xpcall(self.grammar, debug.traceback, self, lua, mode)
+ ast.close()
+ if not suc then
+ return nil, res
+ end
+ if not res then
+ state.pushError(err)
+ end
+ state.ast = res
+ return state
+end
diff --git a/script/parser/relabel.lua b/script/parser/relabel.lua
new file mode 100644
index 00000000..ac902403
--- /dev/null
+++ b/script/parser/relabel.lua
@@ -0,0 +1,361 @@
+-- $Id: re.lua,v 1.44 2013/03/26 20:11:40 roberto Exp $
+
+-- imported functions and modules
+local tonumber, type, print, error = tonumber, type, print, error
+local pcall = pcall
+local setmetatable = setmetatable
+local tinsert, concat = table.insert, table.concat
+local rep = string.rep
+local m = require"lpeglabel"
+
+-- 'm' will be used to parse expressions, and 'mm' will be used to
+-- create expressions; that is, 're' runs on 'm', creating patterns
+-- on 'mm'
+local mm = m
+
+-- pattern's metatable
+local mt = getmetatable(mm.P(0))
+
+
+
+-- No more global accesses after this point
+_ENV = nil
+
+
+local any = m.P(1)
+local dummy = mm.P(false)
+
+
+local errinfo = {
+ NoPatt = "no pattern found",
+ ExtraChars = "unexpected characters after the pattern",
+
+ ExpPatt1 = "expected a pattern after '/'",
+
+ ExpPatt2 = "expected a pattern after '&'",
+ ExpPatt3 = "expected a pattern after '!'",
+
+ ExpPatt4 = "expected a pattern after '('",
+ ExpPatt5 = "expected a pattern after ':'",
+ ExpPatt6 = "expected a pattern after '{~'",
+ ExpPatt7 = "expected a pattern after '{|'",
+
+ ExpPatt8 = "expected a pattern after '<-'",
+
+ ExpPattOrClose = "expected a pattern or closing '}' after '{'",
+
+ ExpNumName = "expected a number, '+', '-' or a name (no space) after '^'",
+ ExpCap = "expected a string, number, '{}' or name after '->'",
+
+ ExpName1 = "expected the name of a rule after '=>'",
+ ExpName2 = "expected the name of a rule after '=' (no space)",
+ ExpName3 = "expected the name of a rule after '<' (no space)",
+
+ ExpLab1 = "expected a label after '{'",
+
+ ExpNameOrLab = "expected a name or label after '%' (no space)",
+
+ ExpItem = "expected at least one item after '[' or '^'",
+
+ MisClose1 = "missing closing ')'",
+ MisClose2 = "missing closing ':}'",
+ MisClose3 = "missing closing '~}'",
+ MisClose4 = "missing closing '|}'",
+ MisClose5 = "missing closing '}'", -- for the captures
+
+ MisClose6 = "missing closing '>'",
+ MisClose7 = "missing closing '}'", -- for the labels
+
+ MisClose8 = "missing closing ']'",
+
+ MisTerm1 = "missing terminating single quote",
+ MisTerm2 = "missing terminating double quote",
+}
+
+local function expect (pattern, label)
+ return pattern + m.T(label)
+end
+
+
+-- Pre-defined names
+local Predef = { nl = m.P"\n" }
+
+
+local mem
+local fmem
+local gmem
+
+
+local function updatelocale ()
+ mm.locale(Predef)
+ Predef.a = Predef.alpha
+ Predef.c = Predef.cntrl
+ Predef.d = Predef.digit
+ Predef.g = Predef.graph
+ Predef.l = Predef.lower
+ Predef.p = Predef.punct
+ Predef.s = Predef.space
+ Predef.u = Predef.upper
+ Predef.w = Predef.alnum
+ Predef.x = Predef.xdigit
+ Predef.A = any - Predef.a
+ Predef.C = any - Predef.c
+ Predef.D = any - Predef.d
+ Predef.G = any - Predef.g
+ Predef.L = any - Predef.l
+ Predef.P = any - Predef.p
+ Predef.S = any - Predef.s
+ Predef.U = any - Predef.u
+ Predef.W = any - Predef.w
+ Predef.X = any - Predef.x
+ mem = {} -- restart memoization
+ fmem = {}
+ gmem = {}
+ local mt = {__mode = "v"}
+ setmetatable(mem, mt)
+ setmetatable(fmem, mt)
+ setmetatable(gmem, mt)
+end
+
+
+updatelocale()
+
+
+
+local I = m.P(function (s,i) print(i, s:sub(1, i-1)); return i end)
+
+
+local function getdef (id, defs)
+ local c = defs and defs[id]
+ if not c then
+ error("undefined name: " .. id)
+ end
+ return c
+end
+
+
+local function mult (p, n)
+ local np = mm.P(true)
+ while n >= 1 do
+ if n%2 >= 1 then np = np * p end
+ p = p * p
+ n = n/2
+ end
+ return np
+end
+
+local function equalcap (s, i, c)
+ if type(c) ~= "string" then return nil end
+ local e = #c + i
+ if s:sub(i, e - 1) == c then return e else return nil end
+end
+
+
+local S = (Predef.space + "--" * (any - Predef.nl)^0)^0
+
+local name = m.C(m.R("AZ", "az", "__") * m.R("AZ", "az", "__", "09")^0)
+
+local arrow = S * "<-"
+
+-- a defined name only have meaning in a given environment
+local Def = name * m.Carg(1)
+
+local num = m.C(m.R"09"^1) * S / tonumber
+
+local String = "'" * m.C((any - "'" - m.P"\n")^0) * expect("'", "MisTerm1")
+ + '"' * m.C((any - '"' - m.P"\n")^0) * expect('"', "MisTerm2")
+
+
+local defined = "%" * Def / function (c,Defs)
+ local cat = Defs and Defs[c] or Predef[c]
+ if not cat then
+ error("name '" .. c .. "' undefined")
+ end
+ return cat
+end
+
+local Range = m.Cs(any * (m.P"-"/"") * (any - "]")) / mm.R
+
+local item = defined + Range + m.C(any - m.P"\n")
+
+local Class =
+ "["
+ * (m.C(m.P"^"^-1)) -- optional complement symbol
+ * m.Cf(expect(item, "ExpItem") * (item - "]")^0, mt.__add)
+ / function (c, p) return c == "^" and any - p or p end
+ * expect("]", "MisClose8")
+
+local function adddef (t, k, exp)
+ if t[k] then
+ -- TODO 改了一下这里的代码,重复定义不会抛错
+ --error("'"..k.."' already defined as a rule")
+ else
+ t[k] = exp
+ end
+ return t
+end
+
+local function firstdef (n, r) return adddef({n}, n, r) end
+
+
+local function NT (n, b)
+ if not b then
+ error("rule '"..n.."' used outside a grammar")
+ else return mm.V(n)
+ end
+end
+
+
+local exp = m.P{ "Exp",
+ Exp = S * ( m.V"Grammar"
+ + m.Cf(m.V"Seq" * (S * "/" * expect(S * m.V"Seq", "ExpPatt1"))^0, mt.__add) );
+ Seq = m.Cf(m.Cc(m.P"") * m.V"Prefix" * (S * m.V"Prefix")^0, mt.__mul);
+ Prefix = "&" * expect(S * m.V"Prefix", "ExpPatt2") / mt.__len
+ + "!" * expect(S * m.V"Prefix", "ExpPatt3") / mt.__unm
+ + m.V"Suffix";
+ Suffix = m.Cf(m.V"Primary" *
+ ( S * ( m.P"+" * m.Cc(1, mt.__pow)
+ + m.P"*" * m.Cc(0, mt.__pow)
+ + m.P"?" * m.Cc(-1, mt.__pow)
+ + "^" * expect( m.Cg(num * m.Cc(mult))
+ + m.Cg(m.C(m.S"+-" * m.R"09"^1) * m.Cc(mt.__pow)
+ + name * m.Cc"lab"
+ ),
+ "ExpNumName")
+ + "->" * expect(S * ( m.Cg((String + num) * m.Cc(mt.__div))
+ + m.P"{}" * m.Cc(nil, m.Ct)
+ + m.Cg(Def / getdef * m.Cc(mt.__div))
+ ),
+ "ExpCap")
+ + "=>" * expect(S * m.Cg(Def / getdef * m.Cc(m.Cmt)),
+ "ExpName1")
+ )
+ )^0, function (a,b,f) if f == "lab" then return a + mm.T(b) else return f(a,b) end end );
+ Primary = "(" * expect(m.V"Exp", "ExpPatt4") * expect(S * ")", "MisClose1")
+ + String / mm.P
+ + Class
+ + defined
+ + "%" * expect(m.P"{", "ExpNameOrLab")
+ * expect(S * m.V"Label", "ExpLab1")
+ * expect(S * "}", "MisClose7") / mm.T
+ + "{:" * (name * ":" + m.Cc(nil)) * expect(m.V"Exp", "ExpPatt5")
+ * expect(S * ":}", "MisClose2")
+ / function (n, p) return mm.Cg(p, n) end
+ + "=" * expect(name, "ExpName2")
+ / function (n) return mm.Cmt(mm.Cb(n), equalcap) end
+ + m.P"{}" / mm.Cp
+ + "{~" * expect(m.V"Exp", "ExpPatt6")
+ * expect(S * "~}", "MisClose3") / mm.Cs
+ + "{|" * expect(m.V"Exp", "ExpPatt7")
+ * expect(S * "|}", "MisClose4") / mm.Ct
+ + "{" * expect(m.V"Exp", "ExpPattOrClose")
+ * expect(S * "}", "MisClose5") / mm.C
+ + m.P"." * m.Cc(any)
+ + (name * -arrow + "<" * expect(name, "ExpName3")
+ * expect(">", "MisClose6")) * m.Cb("G") / NT;
+ Label = num + name;
+ Definition = name * arrow * expect(m.V"Exp", "ExpPatt8");
+ Grammar = m.Cg(m.Cc(true), "G")
+ * m.Cf(m.V"Definition" / firstdef * (S * m.Cg(m.V"Definition"))^0,
+ adddef) / mm.P;
+}
+
+local pattern = S * m.Cg(m.Cc(false), "G") * expect(exp, "NoPatt") / mm.P
+ * S * expect(-any, "ExtraChars")
+
+local function lineno (s, i)
+ if i == 1 then return 1, 1 end
+ local adjustment = 0
+ -- report the current line if at end of line, not the next
+ if s:sub(i,i) == '\n' then
+ i = i-1
+ adjustment = 1
+ end
+ local rest, num = s:sub(1,i):gsub("[^\n]*\n", "")
+ local r = #rest
+ return 1 + num, (r ~= 0 and r or 1) + adjustment
+end
+
+local function calcline (s, i)
+ if i == 1 then return 1, 1 end
+ local rest, line = s:sub(1,i):gsub("[^\n]*\n", "")
+ local col = #rest
+ return 1 + line, col ~= 0 and col or 1
+end
+
+
+local function splitlines(str)
+ local t = {}
+ local function helper(line) tinsert(t, line) return "" end
+ helper((str:gsub("(.-)\r?\n", helper)))
+ return t
+end
+
+local function compile (p, defs)
+ if mm.type(p) == "pattern" then return p end -- already compiled
+ p = p .. " " -- for better reporting of column numbers in errors when at EOF
+ local ok, cp, label, poserr = pcall(function() return pattern:match(p, 1, defs) end)
+ if not ok and cp then
+ if type(cp) == "string" then
+ cp = cp:gsub("^[^:]+:[^:]+: ", "")
+ end
+ error(cp, 3)
+ end
+ if not cp then
+ local lines = splitlines(p)
+ local line, col = lineno(p, poserr)
+ local err = {}
+ tinsert(err, "L" .. line .. ":C" .. col .. ": " .. errinfo[label])
+ tinsert(err, lines[line])
+ tinsert(err, rep(" ", col-1) .. "^")
+ error("syntax error(s) in pattern\n" .. concat(err, "\n"), 3)
+ end
+ return cp
+end
+
+local function match (s, p, i)
+ local cp = mem[p]
+ if not cp then
+ cp = compile(p)
+ mem[p] = cp
+ end
+ return cp:match(s, i or 1)
+end
+
+local function find (s, p, i)
+ local cp = fmem[p]
+ if not cp then
+ cp = compile(p) / 0
+ cp = mm.P{ mm.Cp() * cp * mm.Cp() + 1 * mm.V(1) }
+ fmem[p] = cp
+ end
+ local i, e = cp:match(s, i or 1)
+ if i then return i, e - 1
+ else return i
+ end
+end
+
+local function gsub (s, p, rep)
+ local g = gmem[p] or {} -- ensure gmem[p] is not collected while here
+ gmem[p] = g
+ local cp = g[rep]
+ if not cp then
+ cp = compile(p)
+ cp = mm.Cs((cp / rep + 1)^0)
+ g[rep] = cp
+ end
+ return cp:match(s)
+end
+
+
+-- exported names
+local re = {
+ compile = compile,
+ match = match,
+ find = find,
+ gsub = gsub,
+ updatelocale = updatelocale,
+ calcline = calcline
+}
+
+return re
diff --git a/script/parser/split.lua b/script/parser/split.lua
new file mode 100644
index 00000000..6ce4a4e7
--- /dev/null
+++ b/script/parser/split.lua
@@ -0,0 +1,9 @@
+local m = require 'lpeglabel'
+
+local NL = m.P'\r\n' + m.S'\r\n'
+local LINE = m.C(1 - NL)
+
+return function (str)
+ local MATCH = m.Ct((LINE * NL)^0 * LINE)
+ return MATCH:match(str)
+end
diff --git a/script/proto/define.lua b/script/proto/define.lua
new file mode 100644
index 00000000..966a5161
--- /dev/null
+++ b/script/proto/define.lua
@@ -0,0 +1,287 @@
+local guide = require 'parser.guide'
+local util = require 'utility'
+
+local m = {}
+
+--- 获取 position 对应的光标位置
+---@param lines table
+---@param text string
+---@param position position
+---@return integer
+function m.offset(lines, text, position)
+ local row = position.line + 1
+ local start = guide.lineRange(lines, row)
+ if start <= 0 or start > #text then
+ return #text + 1
+ end
+ local offset = utf8.offset(text, position.character + 1, start)
+ return offset - 1
+end
+
+--- 获取 position 对应的光标位置(根据附近的单词)
+---@param lines table
+---@param text string
+---@param position position
+---@return integer
+function m.offsetOfWord(lines, text, position)
+ local row = position.line + 1
+ local start = guide.lineRange(lines, row)
+ if start <= 0 or start > #text then
+ return #text + 1
+ end
+ local offset = utf8.offset(text, position.character + 1, start)
+ if offset > #text
+ or text:sub(offset-1, offset):match '[%w_][^%w_]' then
+ offset = offset - 1
+ end
+ return offset
+end
+
+--- 将光标位置转化为 position
+---@alias position table
+---@param lines table
+---@param text string
+---@param offset integer
+---@return position
+function m.position(lines, text, offset)
+ local row, col = guide.positionOf(lines, offset)
+ local start = guide.lineRange(lines, row)
+ if start < 1 then
+ start = 1
+ end
+ local ucol = util.utf8Len(text, start, start + col - 1)
+ if row < 1 then
+ row = 1
+ end
+ return {
+ line = row - 1,
+ character = ucol,
+ }
+end
+
+--- 将起点与终点位置转化为 range
+---@alias range table
+---@param lines table
+---@param text string
+---@param offset1 integer
+---@param offset2 integer
+function m.range(lines, text, offset1, offset2)
+ local range = {
+ start = m.position(lines, text, offset1),
+ ['end'] = m.position(lines, text, offset2),
+ }
+ if range.start.character > 0 then
+ range.start.character = range.start.character - 1
+ end
+ return range
+end
+
+---@alias location table
+---@param uri string
+---@param range range
+---@return location
+function m.location(uri, range)
+ return {
+ uri = uri,
+ range = range,
+ }
+end
+
+---@alias locationLink table
+---@param uri string
+---@param range range
+---@param selection range
+---@param origin range
+function m.locationLink(uri, range, selection, origin)
+ return {
+ targetUri = uri,
+ targetRange = range,
+ targetSelectionRange = selection,
+ originSelectionRange = origin,
+ }
+end
+
+function m.textEdit(range, newtext)
+ return {
+ range = range,
+ newText = newtext,
+ }
+end
+
+--- 诊断等级
+m.DiagnosticSeverity = {
+ Error = 1,
+ Warning = 2,
+ Information = 3,
+ Hint = 4,
+}
+
+--- 诊断类型与默认等级
+m.DiagnosticDefaultSeverity = {
+ ['unused-local'] = 'Hint',
+ ['unused-function'] = 'Hint',
+ ['undefined-global'] = 'Warning',
+ ['global-in-nil-env'] = 'Warning',
+ ['unused-label'] = 'Hint',
+ ['unused-vararg'] = 'Hint',
+ ['trailing-space'] = 'Hint',
+ ['redefined-local'] = 'Hint',
+ ['newline-call'] = 'Information',
+ ['newfield-call'] = 'Warning',
+ ['redundant-parameter'] = 'Hint',
+ ['ambiguity-1'] = 'Warning',
+ ['lowercase-global'] = 'Information',
+ ['undefined-env-child'] = 'Information',
+ ['duplicate-index'] = 'Warning',
+ ['empty-block'] = 'Hint',
+ ['redundant-value'] = 'Hint',
+ ['code-after-break'] = 'Hint',
+
+ ['duplicate-doc-class'] = 'Warning',
+ ['undefined-doc-class'] = 'Warning',
+ ['undefined-doc-name'] = 'Warning',
+ ['circle-doc-class'] = 'Warning',
+ ['undefined-doc-param'] = 'Warning',
+ ['duplicate-doc-param'] = 'Warning',
+ ['doc-field-no-class'] = 'Warning',
+ ['duplicate-doc-field'] = 'Warning',
+}
+
+--- 诊断报告标签
+m.DiagnosticTag = {
+ Unnecessary = 1,
+ Deprecated = 2,
+}
+
+m.DocumentHighlightKind = {
+ Text = 1,
+ Read = 2,
+ Write = 3,
+}
+
+m.MessageType = {
+ Error = 1,
+ Warning = 2,
+ Info = 3,
+ Log = 4,
+}
+
+m.FileChangeType = {
+ Created = 1,
+ Changed = 2,
+ Deleted = 3,
+}
+
+m.CompletionItemKind = {
+ Text = 1,
+ Method = 2,
+ Function = 3,
+ Constructor = 4,
+ Field = 5,
+ Variable = 6,
+ Class = 7,
+ Interface = 8,
+ Module = 9,
+ Property = 10,
+ Unit = 11,
+ Value = 12,
+ Enum = 13,
+ Keyword = 14,
+ Snippet = 15,
+ Color = 16,
+ File = 17,
+ Reference = 18,
+ Folder = 19,
+ EnumMember = 20,
+ Constant = 21,
+ Struct = 22,
+ Event = 23,
+ Operator = 24,
+ TypeParameter = 25,
+}
+
+m.DiagnosticSeverity = {
+ Error = 1,
+ Warning = 2,
+ Information = 3,
+ Hint = 4,
+}
+
+m.ErrorCodes = {
+ -- Defined by JSON RPC
+ ParseError = -32700,
+ InvalidRequest = -32600,
+ MethodNotFound = -32601,
+ InvalidParams = -32602,
+ InternalError = -32603,
+ serverErrorStart = -32099,
+ serverErrorEnd = -32000,
+ ServerNotInitialized = -32002,
+ UnknownErrorCode = -32001,
+
+ -- Defined by the protocol.
+ RequestCancelled = -32800,
+}
+
+m.SymbolKind = {
+ File = 1,
+ Module = 2,
+ Namespace = 3,
+ Package = 4,
+ Class = 5,
+ Method = 6,
+ Property = 7,
+ Field = 8,
+ Constructor = 9,
+ Enum = 10,
+ Interface = 11,
+ Function = 12,
+ Variable = 13,
+ Constant = 14,
+ String = 15,
+ Number = 16,
+ Boolean = 17,
+ Array = 18,
+ Object = 19,
+ Key = 20,
+ Null = 21,
+ EnumMember = 22,
+ Struct = 23,
+ Event = 24,
+ Operator = 25,
+ TypeParameter = 26,
+}
+
+m.TokenModifiers = {
+ ["declaration"] = 1 << 0,
+ ["documentation"] = 1 << 1,
+ ["static"] = 1 << 2,
+ ["abstract"] = 1 << 3,
+ ["deprecated"] = 1 << 4,
+ ["readonly"] = 1 << 5,
+}
+
+m.TokenTypes = {
+ ["comment"] = 0,
+ ["keyword"] = 1,
+ ["number"] = 2,
+ ["regexp"] = 3,
+ ["operator"] = 4,
+ ["namespace"] = 5,
+ ["type"] = 6,
+ ["struct"] = 7,
+ ["class"] = 8,
+ ["interface"] = 9,
+ ["enum"] = 10,
+ ["typeParameter"] = 11,
+ ["function"] = 12,
+ ["member"] = 13,
+ ["macro"] = 14,
+ ["variable"] = 15,
+ ["parameter"] = 16,
+ ["property"] = 17,
+ ["label"] = 18,
+}
+
+
+return m
diff --git a/script/proto/init.lua b/script/proto/init.lua
new file mode 100644
index 00000000..33e637f6
--- /dev/null
+++ b/script/proto/init.lua
@@ -0,0 +1,3 @@
+local proto = require 'proto.proto'
+
+return proto
diff --git a/script/proto/proto.lua b/script/proto/proto.lua
new file mode 100644
index 00000000..d8538d8f
--- /dev/null
+++ b/script/proto/proto.lua
@@ -0,0 +1,147 @@
+local subprocess = require 'bee.subprocess'
+local util = require 'utility'
+local await = require 'await'
+local pub = require 'pub'
+local jsonrpc = require 'jsonrpc'
+local define = require 'proto.define'
+local timer = require 'timer'
+local json = require 'json'
+
+local reqCounter = util.counter()
+
+local m = {}
+
+m.ability = {}
+m.waiting = {}
+m.holdon = {}
+
+function m.getMethodName(proto)
+ if proto.method:sub(1, 2) == '$/' then
+ return proto.method:sub(3), true
+ else
+ return proto.method, false
+ end
+end
+
+function m.on(method, callback)
+ m.ability[method] = callback
+end
+
+function m.response(id, res)
+ if id == nil then
+ log.error('Response id is nil!', util.dump(res))
+ return
+ end
+ assert(m.holdon[id])
+ m.holdon[id] = nil
+ local data = {}
+ data.id = id
+ data.result = res == nil and json.null or res
+ local buf = jsonrpc.encode(data)
+ --log.debug('Response', id, #buf)
+ io.stdout:write(buf)
+end
+
+function m.responseErr(id, code, message)
+ if id == nil then
+ log.error('Response id is nil!', util.dump(message))
+ return
+ end
+ local buf = jsonrpc.encode {
+ id = id,
+ error = {
+ code = code,
+ message = message,
+ }
+ }
+ --log.debug('ResponseErr', id, #buf)
+ io.stdout:write(buf)
+end
+
+function m.notify(name, params)
+ local buf = jsonrpc.encode {
+ method = name,
+ params = params,
+ }
+ --log.debug('Notify', name, #buf)
+ io.stdout:write(buf)
+end
+
+function m.awaitRequest(name, params)
+ local id = reqCounter()
+ local buf = jsonrpc.encode {
+ id = id,
+ method = name,
+ params = params,
+ }
+ --log.debug('Request', name, #buf)
+ io.stdout:write(buf)
+ return await.wait(function (waker)
+ m.waiting[id] = waker
+ end)
+end
+
+function m.doMethod(proto)
+ local method, optional = m.getMethodName(proto)
+ local abil = m.ability[method]
+ if not abil then
+ if not optional then
+ log.warn('Recieved unknown proto: ' .. method)
+ end
+ if proto.id then
+ m.responseErr(proto.id, define.ErrorCodes.MethodNotFound, method)
+ end
+ return
+ end
+ if proto.id then
+ m.holdon[proto.id] = method
+ end
+ await.call(function ()
+ --log.debug('Start method:', method)
+ local clock = os.clock()
+ local ok = true
+ local res
+ -- 任务可能在执行过程中被中断,通过close来捕获
+ local response <close> = util.defer(function ()
+ local passed = os.clock() - clock
+ if passed > 0.2 then
+ log.debug(('Method [%s] takes [%.3f]sec.'):format(method, passed))
+ end
+ --log.debug('Finish method:', method)
+ if not proto.id then
+ return
+ end
+ if ok then
+ m.response(proto.id, res)
+ else
+ m.responseErr(proto.id, define.ErrorCodes.InternalError, res)
+ end
+ end)
+ ok, res = xpcall(abil, log.error, proto.params)
+ end)
+end
+
+function m.doResponse(proto)
+ local id = proto.id
+ local waker = m.waiting[id]
+ if not waker then
+ log.warn('Response id not found: ' .. util.dump(proto))
+ return
+ end
+ m.waiting[id] = nil
+ if proto.error then
+ log.warn(('Response error [%d]: %s'):format(proto.error.code, proto.error.message))
+ return
+ end
+ waker(proto.result)
+end
+
+function m.listen()
+ subprocess.filemode(io.stdin, 'b')
+ subprocess.filemode(io.stdout, 'b')
+ io.stdin:setvbuf 'no'
+ io.stdout:setvbuf 'no'
+ pub.task('loadProto')
+end
+
+return m
diff --git a/script/provider/capability.lua b/script/provider/capability.lua
new file mode 100644
index 00000000..23ec27b0
--- /dev/null
+++ b/script/provider/capability.lua
@@ -0,0 +1,61 @@
+local sp = require 'bee.subprocess'
+local nonil = require 'without-check-nil'
+local client = require 'provider.client'
+
+local m = {}
+
+local function allWords()
+ local str = [[abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.:('"[,#*@| ]]
+ local list = {}
+ for c in str:gmatch '.' do
+ list[#list+1] = c
+ end
+ return list
+end
+
+function m.getIniter()
+ local initer = {
+ -- 文本同步方式
+ textDocumentSync = {
+ -- 打开关闭文本时通知
+ openClose = true,
+ -- 文本改变时完全通知 TODO 支持差量更新(2)
+ change = 1,
+ },
+
+ hoverProvider = true,
+ definitionProvider = true,
+ referencesProvider = true,
+ renameProvider = {
+ prepareProvider = true,
+ },
+ documentSymbolProvider = true,
+ workspaceSymbolProvider = true,
+ documentHighlightProvider = true,
+ codeActionProvider = true,
+ signatureHelpProvider = {
+ triggerCharacters = { '(', ',' },
+ },
+ executeCommandProvider = {
+ commands = {
+ 'lua.removeSpace:' .. sp:get_id(),
+ 'lua.solve:' .. sp:get_id(),
+ },
+ }
+ --documentOnTypeFormattingProvider = {
+ -- firstTriggerCharacter = '}',
+ --},
+ }
+
+ nonil.enable()
+ if not client.info.capabilities.textDocument.completion.dynamicRegistration then
+ initer.completionProvider = {
+ triggerCharacters = allWords(),
+ }
+ end
+ nonil.disable()
+
+ return initer
+end
+
+return m
diff --git a/script/provider/client.lua b/script/provider/client.lua
new file mode 100644
index 00000000..c1b16f0f
--- /dev/null
+++ b/script/provider/client.lua
@@ -0,0 +1,18 @@
+local nonil = require 'without-check-nil'
+local util = require 'utility'
+
+local m = {}
+
+function m.client()
+ nonil.enable()
+ local name = m.info.clientInfo.name
+ nonil.disable()
+ return name
+end
+
+function m.init(t)
+ log.debug('Client init', util.dump(t))
+ m.info = t
+end
+
+return m
diff --git a/script/provider/completion.lua b/script/provider/completion.lua
new file mode 100644
index 00000000..e506cd7b
--- /dev/null
+++ b/script/provider/completion.lua
@@ -0,0 +1,54 @@
+local proto = require 'proto'
+
+local isEnable = false
+
+local function allWords()
+ local str = [[abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.:('"[,#*@| ]]
+ local list = {}
+ for c in str:gmatch '.' do
+ list[#list+1] = c
+ end
+ return list
+end
+
+local function enable()
+ -- TODO 检查客户端是否支持动态注册自动完成
+ if isEnable then
+ return
+ end
+ isEnable = true
+ log.debug('Enable completion.')
+ proto.awaitRequest('client/registerCapability', {
+ registrations = {
+ {
+ id = 'completion',
+ method = 'textDocument/completion',
+ registerOptions = {
+ resolveProvider = true,
+ triggerCharacters = allWords(),
+ },
+ },
+ }
+ })
+end
+
+local function disable()
+ if not isEnable then
+ return
+ end
+ isEnable = false
+ log.debug('Disable completion.')
+ proto.awaitRequest('client/unregisterCapability', {
+ unregisterations = {
+ {
+ id = 'completion',
+ method = 'textDocument/completion',
+ },
+ }
+ })
+end
+
+return {
+ enable = enable,
+ disable = disable,
+}
diff --git a/script/provider/diagnostic.lua b/script/provider/diagnostic.lua
new file mode 100644
index 00000000..845b9f44
--- /dev/null
+++ b/script/provider/diagnostic.lua
@@ -0,0 +1,303 @@
+local await = require 'await'
+local proto = require 'proto.proto'
+local define = require 'proto.define'
+local lang = require 'language'
+local files = require 'files'
+local config = require 'config'
+local core = require 'core.diagnostics'
+local util = require 'utility'
+local ws = require 'workspace'
+
+local m = {}
+m._start = false
+m.cache = {}
+m.sleepRest = 0.0
+
+local function concat(t, sep)
+ if type(t) ~= 'table' then
+ return t
+ end
+ return table.concat(t, sep)
+end
+
+local function buildSyntaxError(uri, err)
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+ local message = lang.script('PARSER_'..err.type, err.info)
+
+ if err.version then
+ local version = err.info and err.info.version or config.config.runtime.version
+ message = message .. ('(%s)'):format(lang.script('DIAG_NEED_VERSION'
+ , concat(err.version, '/')
+ , version
+ ))
+ end
+
+ local related = err.info and err.info.related
+ local relatedInformation
+ if related then
+ relatedInformation = {}
+ for _, rel in ipairs(related) do
+ local rmessage
+ if rel.message then
+ rmessage = lang.script('PARSER_'..rel.message)
+ else
+ rmessage = text:sub(rel.start, rel.finish)
+ end
+ relatedInformation[#relatedInformation+1] = {
+ message = rmessage,
+ location = define.location(uri, define.range(lines, text, rel.start, rel.finish)),
+ }
+ end
+ end
+
+ return {
+ code = err.type:lower():gsub('_', '-'),
+ range = define.range(lines, text, err.start, err.finish),
+ severity = define.DiagnosticSeverity.Error,
+ source = lang.script.DIAG_SYNTAX_CHECK,
+ message = message,
+ relatedInformation = relatedInformation,
+ }
+end
+
+local function buildDiagnostic(uri, diag)
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+
+ local relatedInformation
+ if diag.related then
+ relatedInformation = {}
+ for _, rel in ipairs(diag.related) do
+ local rtext = files.getText(rel.uri)
+ local rlines = files.getLines(rel.uri)
+ relatedInformation[#relatedInformation+1] = {
+ message = rel.message or rtext:sub(rel.start, rel.finish),
+ location = define.location(rel.uri, define.range(rlines, rtext, rel.start, rel.finish))
+ }
+ end
+ end
+
+ return {
+ range = define.range(lines, text, diag.start, diag.finish),
+ source = lang.script.DIAG_DIAGNOSTICS,
+ severity = diag.level,
+ message = diag.message,
+ code = diag.code,
+ tags = diag.tags,
+ relatedInformation = relatedInformation,
+ }
+end
+
+local function merge(a, b)
+ if not a and not b then
+ return nil
+ end
+ local t = {}
+ if a then
+ for i = 1, #a do
+ t[#t+1] = a[i]
+ end
+ end
+ if b then
+ for i = 1, #b do
+ t[#t+1] = b[i]
+ end
+ end
+ return t
+end
+
+function m.clear(uri)
+ local luri = uri:lower()
+ if not m.cache[luri] then
+ return
+ end
+ m.cache[luri] = nil
+ proto.notify('textDocument/publishDiagnostics', {
+ uri = files.getOriginUri(luri) or uri,
+ diagnostics = {},
+ })
+end
+
+function m.clearAll()
+ for luri in pairs(m.cache) do
+ m.clear(luri)
+ end
+end
+
+function m.syntaxErrors(uri, ast)
+ if #ast.errs == 0 then
+ return nil
+ end
+
+ local results = {}
+
+ for _, err in ipairs(ast.errs) do
+ results[#results+1] = buildSyntaxError(uri, err)
+ end
+
+ return results
+end
+
+function m.diagnostics(uri, diags)
+ if not m._start then
+ return
+ end
+
+ core(uri, function (results)
+ if #results == 0 then
+ return
+ end
+ for i = 1, #results do
+ diags[#diags+1] = buildDiagnostic(uri, results[i])
+ end
+ end)
+end
+
+function m.doDiagnostic(uri)
+ if not config.config.diagnostics.enable then
+ return
+ end
+ uri = uri:lower()
+ if files.isLibrary(uri) then
+ return
+ end
+
+ await.delay()
+
+ local ast = files.getAst(uri)
+ if not ast then
+ m.clear(uri)
+ return
+ end
+
+ local syntax = m.syntaxErrors(uri, ast)
+ local diags = {}
+
+ local function pushResult()
+ local full = merge(syntax, diags)
+ if not full then
+ m.clear(uri)
+ return
+ end
+
+ if util.equal(m.cache, full) then
+ return
+ end
+ m.cache[uri] = full
+
+ proto.notify('textDocument/publishDiagnostics', {
+ uri = files.getOriginUri(uri),
+ diagnostics = full,
+ })
+ end
+
+ if await.hasID 'diagnosticsAll' then
+ m.checkStepResult = nil
+ else
+ local clock = os.clock()
+ m.checkStepResult = function ()
+ if os.clock() - clock >= 0.2 then
+ pushResult()
+ clock = os.clock()
+ end
+ end
+ end
+
+ m.diagnostics(uri, diags)
+ pushResult()
+end
+
+function m.refresh(uri)
+ if not m._start then
+ return
+ end
+ await.call(function ()
+ if uri then
+ m.doDiagnostic(uri)
+ end
+ m.diagnosticsAll()
+ end, 'files.version')
+end
+
+function m.diagnosticsAll()
+ if not config.config.diagnostics.enable then
+ m.clearAll()
+ return
+ end
+ if not m._start then
+ return
+ end
+ local delay = config.config.diagnostics.workspaceDelay / 1000
+ if delay < 0 then
+ return
+ end
+ await.close 'diagnosticsAll'
+ await.call(function ()
+ await.sleep(delay)
+ m.diagnosticsAllClock = os.clock()
+ local clock = os.clock()
+ for uri in files.eachFile() do
+ m.doDiagnostic(uri)
+ await.delay()
+ end
+ log.debug('全文诊断耗时:', os.clock() - clock)
+ end, 'files.version', 'diagnosticsAll')
+end
+
+function m.start()
+ m._start = true
+ m.diagnosticsAll()
+end
+
+function m.checkStepResult()
+ if await.hasID 'diagnosticsAll' then
+ return
+ end
+end
+
+function m.checkWorkspaceDiag()
+ if not await.hasID 'diagnosticsAll' then
+ return
+ end
+ local speedRate = config.config.diagnostics.workspaceRate
+ if speedRate <= 0 or speedRate >= 100 then
+ return
+ end
+ local currentClock = os.clock()
+ local passed = currentClock - m.diagnosticsAllClock
+ local sleepTime = passed * (100 - speedRate) / speedRate + m.sleepRest
+ m.sleepRest = 0.0
+ if sleepTime < 0.001 then
+ m.sleepRest = m.sleepRest + sleepTime
+ return
+ end
+ if sleepTime > 0.1 then
+ m.sleepRest = sleepTime - 0.1
+ sleepTime = 0.1
+ end
+ await.sleep(sleepTime)
+ m.diagnosticsAllClock = os.clock()
+ return false
+end
+
+files.watch(function (ev, uri)
+ if ev == 'remove' then
+ m.clear(uri)
+ elseif ev == 'update' then
+ m.refresh(uri)
+ elseif ev == 'open' then
+ m.doDiagnostic(uri)
+ end
+end)
+
+await.watch(function (ev, co)
+ if ev == 'delay' then
+ if m.checkStepResult then
+ m.checkStepResult()
+ end
+ return m.checkWorkspaceDiag()
+ end
+end)
+
+return m
diff --git a/script/provider/init.lua b/script/provider/init.lua
new file mode 100644
index 00000000..7eafb70a
--- /dev/null
+++ b/script/provider/init.lua
@@ -0,0 +1 @@
+require 'provider.provider'
diff --git a/script/provider/markdown.lua b/script/provider/markdown.lua
new file mode 100644
index 00000000..ca76ec89
--- /dev/null
+++ b/script/provider/markdown.lua
@@ -0,0 +1,26 @@
+local mt = {}
+mt.__index = mt
+mt.__name = 'markdown'
+
+function mt:add(language, text)
+ if not text or #text == 0 then
+ return
+ end
+ if language == 'md' then
+ if self._last == 'md' then
+ self[#self+1] = ''
+ end
+ self[#self+1] = text
+ else
+ self[#self+1] = ('```%s\n%s\n```'):format(language, text)
+ end
+ self._last = language
+end
+
+function mt:string()
+ return table.concat(self, '\n')
+end
+
+return function ()
+ return setmetatable({}, mt)
+end
diff --git a/script/provider/provider.lua b/script/provider/provider.lua
new file mode 100644
index 00000000..3508116f
--- /dev/null
+++ b/script/provider/provider.lua
@@ -0,0 +1,642 @@
+local util = require 'utility'
+local cap = require 'provider.capability'
+local completion= require 'provider.completion'
+local semantic = require 'provider.semantic-tokens'
+local await = require 'await'
+local files = require 'files'
+local proto = require 'proto.proto'
+local define = require 'proto.define'
+local workspace = require 'workspace'
+local config = require 'config'
+local library = require 'library'
+local markdown = require 'provider.markdown'
+local client = require 'provider.client'
+local furi = require 'file-uri'
+local pub = require 'pub'
+local fs = require 'bee.filesystem'
+local lang = require 'language'
+
+local function updateConfig()
+ local diagnostics = require 'provider.diagnostic'
+ local vm = require 'vm'
+ local configs = proto.awaitRequest('workspace/configuration', {
+ items = {
+ {
+ scopeUri = workspace.uri,
+ section = 'Lua',
+ },
+ {
+ scopeUri = workspace.uri,
+ section = 'files.associations',
+ },
+ {
+ scopeUri = workspace.uri,
+ section = 'files.exclude',
+ }
+ },
+ })
+
+ local updated = configs[1]
+ local other = {
+ associations = configs[2],
+ exclude = configs[3],
+ }
+
+ local oldConfig = util.deepCopy(config.config)
+ local oldOther = util.deepCopy(config.other)
+ config.setConfig(updated, other)
+ local newConfig = config.config
+ local newOther = config.other
+ if not util.equal(oldConfig.runtime, newConfig.runtime) then
+ library.init()
+ workspace.reload()
+ end
+ if not util.equal(oldConfig.diagnostics, newConfig.diagnostics) then
+ diagnostics.diagnosticsAll()
+ end
+ if not util.equal(oldConfig.plugin, newConfig.plugin) then
+ end
+ if not util.equal(oldConfig.workspace, newConfig.workspace)
+ or not util.equal(oldConfig.plugin, newConfig.plugin)
+ or not util.equal(oldOther.associations, newOther.associations)
+ or not util.equal(oldOther.exclude, newOther.exclude)
+ then
+ workspace.reload()
+ end
+ if not util.equal(oldConfig.luadoc, newConfig.luadoc) then
+ files.flushCache()
+ end
+ if not util.equal(oldConfig.intelliSense, newConfig.intelliSense) then
+ files.flushCache()
+ end
+
+ if newConfig.completion.enable then
+ completion.enable()
+ else
+ completion.disable()
+ end
+ if newConfig.color.mode == 'Semantic' then
+ semantic.enable()
+ else
+ semantic.disable()
+ end
+end
+
+proto.on('initialize', function (params)
+ client.init(params)
+ library.init()
+ workspace.init(params.rootUri)
+ return {
+ capabilities = cap.getIniter(),
+ serverInfo = {
+ name = 'sumneko.lua',
+ },
+ }
+end)
+
+proto.on('initialized', function (params)
+ updateConfig()
+ proto.awaitRequest('client/registerCapability', {
+ registrations = {
+ -- 监视文件变化
+ {
+ id = '0',
+ method = 'workspace/didChangeWatchedFiles',
+ registerOptions = {
+ watchers = {
+ {
+ globPattern = '**/',
+ kind = 1 | 2 | 4,
+ }
+ },
+ },
+ },
+ -- 配置变化
+ {
+ id = '1',
+ method = 'workspace/didChangeConfiguration',
+ }
+ }
+ })
+ await.call(workspace.awaitPreload)
+ return true
+end)
+
+proto.on('exit', function ()
+ log.info('Server exited.')
+ os.exit(true)
+end)
+
+proto.on('shutdown', function ()
+ log.info('Server shutdown.')
+ return true
+end)
+
+proto.on('workspace/didChangeConfiguration', function ()
+ updateConfig()
+end)
+
+proto.on('workspace/didChangeWatchedFiles', function (params)
+ for _, change in ipairs(params.changes) do
+ local uri = change.uri
+ -- TODO 创建文件与删除文件直接重新扫描(文件改名、文件夹删除等情况太复杂了)
+ if change.type == define.FileChangeType.Created
+ or change.type == define.FileChangeType.Deleted then
+ workspace.reload()
+ break
+ elseif change.type == define.FileChangeType.Changed then
+ -- 如果文件处于关闭状态,则立即更新;否则等待didChange协议来更新
+ if files.isLua(uri) and not files.isOpen(uri) then
+ files.setText(uri, pub.awaitTask('loadFile', uri))
+ else
+ local path = furi.decode(uri)
+ local filename = fs.path(path):filename():string()
+ -- 排除类文件发生更改需要重新扫描
+ if files.eq(filename, '.gitignore')
+ or files.eq(filename, '.gitmodules') then
+ workspace.reload()
+ break
+ end
+ end
+ end
+ end
+end)
+
+proto.on('textDocument/didOpen', function (params)
+ local doc = params.textDocument
+ local uri = doc.uri
+ local text = doc.text
+ files.open(uri)
+ files.setText(uri, text)
+end)
+
+proto.on('textDocument/didClose', function (params)
+ local doc = params.textDocument
+ local uri = doc.uri
+ files.close(uri)
+ if not files.isLua(uri) then
+ files.remove(uri)
+ end
+end)
+
+proto.on('textDocument/didChange', function (params)
+ local doc = params.textDocument
+ local change = params.contentChanges
+ local uri = doc.uri
+ local text = change[1].text
+ if files.isLua(uri) or files.isOpen(uri) then
+ --log.debug('didChange:', uri)
+ files.setText(uri, text)
+ --log.debug('setText:', #text)
+ end
+end)
+
+proto.on('textDocument/hover', function (params)
+ await.close 'hover'
+ await.setID 'hover'
+ local core = require 'core.hover'
+ local doc = params.textDocument
+ local uri = doc.uri
+ if not files.exists(uri) then
+ return nil
+ end
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+ local offset = define.offsetOfWord(lines, text, params.position)
+ local hover = core.byUri(uri, offset)
+ if not hover then
+ return nil
+ end
+ local md = markdown()
+ md:add('lua', hover.label)
+ md:add('md', hover.description)
+ return {
+ contents = {
+ value = md:string(),
+ kind = 'markdown',
+ },
+ range = define.range(lines, text, hover.source.start, hover.source.finish),
+ }
+end)
+
+proto.on('textDocument/definition', function (params)
+ local core = require 'core.definition'
+ local uri = params.textDocument.uri
+ if not files.exists(uri) then
+ return nil
+ end
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+ local offset = define.offsetOfWord(lines, text, params.position)
+ local result = core(uri, offset)
+ if not result then
+ return nil
+ end
+ local response = {}
+ for i, info in ipairs(result) do
+ local targetUri = info.uri
+ if targetUri then
+ local targetLines = files.getLines(targetUri)
+ local targetText = files.getText(targetUri)
+ response[i] = define.locationLink(targetUri
+ , define.range(targetLines, targetText, info.target.start, info.target.finish)
+ , define.range(targetLines, targetText, info.target.start, info.target.finish)
+ , define.range(lines, text, info.source.start, info.source.finish)
+ )
+ end
+ end
+ return response
+end)
+
+proto.on('textDocument/references', function (params)
+ local core = require 'core.reference'
+ local uri = params.textDocument.uri
+ if not files.exists(uri) then
+ return nil
+ end
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+ local offset = define.offsetOfWord(lines, text, params.position)
+ local result = core(uri, offset)
+ if not result then
+ return nil
+ end
+ local response = {}
+ for i, info in ipairs(result) do
+ local targetUri = info.uri
+ local targetLines = files.getLines(targetUri)
+ local targetText = files.getText(targetUri)
+ response[i] = define.location(targetUri
+ , define.range(targetLines, targetText, info.target.start, info.target.finish)
+ )
+ end
+ return response
+end)
+
+proto.on('textDocument/documentHighlight', function (params)
+ local core = require 'core.highlight'
+ local uri = params.textDocument.uri
+ if not files.exists(uri) then
+ return nil
+ end
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+ local offset = define.offsetOfWord(lines, text, params.position)
+ local result = core(uri, offset)
+ if not result then
+ return nil
+ end
+ local response = {}
+ for _, info in ipairs(result) do
+ response[#response+1] = {
+ range = define.range(lines, text, info.start, info.finish),
+ kind = info.kind,
+ }
+ end
+ return response
+end)
+
+proto.on('textDocument/rename', function (params)
+ local core = require 'core.rename'
+ local uri = params.textDocument.uri
+ if not files.exists(uri) then
+ return nil
+ end
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+ local offset = define.offsetOfWord(lines, text, params.position)
+ local result = core.rename(uri, offset, params.newName)
+ if not result then
+ return nil
+ end
+ local workspaceEdit = {
+ changes = {},
+ }
+ for _, info in ipairs(result) do
+ local ruri = info.uri
+ local rlines = files.getLines(ruri)
+ local rtext = files.getText(ruri)
+ if not workspaceEdit.changes[ruri] then
+ workspaceEdit.changes[ruri] = {}
+ end
+ local textEdit = define.textEdit(define.range(rlines, rtext, info.start, info.finish), info.text)
+ workspaceEdit.changes[ruri][#workspaceEdit.changes[ruri]+1] = textEdit
+ end
+ return workspaceEdit
+end)
+
+proto.on('textDocument/prepareRename', function (params)
+ local core = require 'core.rename'
+ local uri = params.textDocument.uri
+ if not files.exists(uri) then
+ return nil
+ end
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+ local offset = define.offsetOfWord(lines, text, params.position)
+ local result = core.prepareRename(uri, offset)
+ if not result then
+ return nil
+ end
+ return {
+ range = define.range(lines, text, result.start, result.finish),
+ placeholder = result.text,
+ }
+end)
+
+proto.on('textDocument/completion', function (params)
+ --log.info(util.dump(params))
+ local core = require 'core.completion'
+ --log.debug('completion:', params.context and params.context.triggerKind, params.context and params.context.triggerCharacter)
+ local uri = params.textDocument.uri
+ if not files.exists(uri) then
+ return nil
+ end
+ await.setPriority(1000)
+ local clock = os.clock()
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+ local offset = define.offset(lines, text, params.position)
+ local result = core.completion(uri, offset)
+ local passed = os.clock() - clock
+ if passed > 0.1 then
+ log.warn(('Completion takes %.3f sec.'):format(passed))
+ end
+ if not result then
+ return nil
+ end
+ local easy = false
+ local items = {}
+ for i, res in ipairs(result) do
+ local item = {
+ label = res.label,
+ kind = res.kind,
+ deprecated = res.deprecated,
+ sortText = ('%04d'):format(i),
+ insertText = res.insertText,
+ insertTextFormat = res.insertTextFormat,
+ textEdit = res.textEdit and {
+ range = define.range(
+ lines,
+ text,
+ res.textEdit.start,
+ res.textEdit.finish
+ ),
+ newText = res.textEdit.newText,
+ },
+ additionalTextEdits = res.additionalTextEdits and (function ()
+ local t = {}
+ for j, edit in ipairs(res.additionalTextEdits) do
+ t[j] = {
+ range = define.range(
+ lines,
+ text,
+ edit.start,
+ edit.finish
+ )
+ }
+ end
+ return t
+ end)(),
+ documentation = res.description and {
+ value = res.description,
+ kind = 'markdown',
+ },
+ }
+ if res.id then
+ if easy and os.clock() - clock < 0.05 then
+ local resolved = core.resolve(res.id)
+ if resolved then
+ item.detail = resolved.detail
+ item.documentation = resolved.description and {
+ value = resolved.description,
+ kind = 'markdown',
+ }
+ end
+ else
+ easy = false
+ item.data = {
+ version = files.globalVersion,
+ id = res.id,
+ }
+ end
+ end
+ items[i] = item
+ end
+ return {
+ isIncomplete = false,
+ items = items,
+ }
+end)
+
+proto.on('completionItem/resolve', function (item)
+ local core = require 'core.completion'
+ if not item.data then
+ return item
+ end
+ local globalVersion = item.data.version
+ local id = item.data.id
+ if globalVersion ~= files.globalVersion then
+ return item
+ end
+ --await.setPriority(1000)
+ local resolved = core.resolve(id)
+ if not resolved then
+ return nil
+ end
+ item.detail = resolved.detail
+ item.documentation = resolved.description and {
+ value = resolved.description,
+ kind = 'markdown',
+ }
+ return item
+end)
+
+proto.on('textDocument/signatureHelp', function (params)
+ if not config.config.signatureHelp.enable then
+ return nil
+ end
+ local uri = params.textDocument.uri
+ if not files.exists(uri) then
+ return nil
+ end
+ await.close('signatureHelp')
+ await.setID('signatureHelp')
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+ local offset = define.offset(lines, text, params.position)
+ local core = require 'core.signature'
+ local results = core(uri, offset)
+ if not results then
+ return nil
+ end
+ local infos = {}
+ for i, result in ipairs(results) do
+ local parameters = {}
+ for j, param in ipairs(result.params) do
+ parameters[j] = {
+ label = {
+ param.label[1] - 1,
+ param.label[2],
+ }
+ }
+ end
+ infos[i] = {
+ label = result.label,
+ parameters = parameters,
+ activeParameter = result.index - 1,
+ documentation = result.description and {
+ value = result.description,
+ kind = 'markdown',
+ },
+ }
+ end
+ return {
+ signatures = infos,
+ }
+end)
+
+proto.on('textDocument/documentSymbol', function (params)
+ local core = require 'core.document-symbol'
+ local uri = params.textDocument.uri
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+ while not lines or not text do
+ await.sleep(0.1)
+ lines = files.getLines(uri)
+ text = files.getText(uri)
+ end
+
+ local symbols = core(uri)
+ if not symbols then
+ return nil
+ end
+
+ local function convert(symbol)
+ await.delay()
+ symbol.range = define.range(
+ lines,
+ text,
+ symbol.range[1],
+ symbol.range[2]
+ )
+ symbol.selectionRange = define.range(
+ lines,
+ text,
+ symbol.selectionRange[1],
+ symbol.selectionRange[2]
+ )
+ if symbol.name == '' then
+ symbol.name = lang.script.SYMBOL_ANONYMOUS
+ end
+ symbol.valueRange = nil
+ if symbol.children then
+ for _, child in ipairs(symbol.children) do
+ convert(child)
+ end
+ end
+ end
+
+ for _, symbol in ipairs(symbols) do
+ convert(symbol)
+ end
+
+ return symbols
+end)
+
+proto.on('textDocument/codeAction', function (params)
+ local core = require 'core.code-action'
+ local uri = params.textDocument.uri
+ local range = params.range
+ local diagnostics = params.context.diagnostics
+ local results = core(uri, range, diagnostics)
+
+ if not results or #results == 0 then
+ return nil
+ end
+
+ return results
+end)
+
+proto.on('workspace/executeCommand', function (params)
+ local command = params.command:gsub(':.+', '')
+ if command == 'lua.removeSpace' then
+ local core = require 'core.command.removeSpace'
+ return core(params.arguments[1])
+ elseif command == 'lua.solve' then
+ local core = require 'core.command.solve'
+ return core(params.arguments[1])
+ end
+end)
+
+proto.on('workspace/symbol', function (params)
+ local core = require 'core.workspace-symbol'
+
+ await.close('workspace/symbol')
+ await.setID('workspace/symbol')
+
+ local symbols = core(params.query)
+ if not symbols or #symbols == 0 then
+ return nil
+ end
+
+ local function convert(symbol)
+ symbol.location = define.location(
+ symbol.uri,
+ define.range(
+ files.getLines(symbol.uri),
+ files.getText(symbol.uri),
+ symbol.range[1],
+ symbol.range[2]
+ )
+ )
+ symbol.uri = nil
+ end
+
+ for _, symbol in ipairs(symbols) do
+ convert(symbol)
+ end
+
+ return symbols
+end)
+
+
+proto.on('textDocument/semanticTokens/full', function (params)
+ local core = require 'core.semantic-tokens'
+ local uri = params.textDocument.uri
+ log.debug('semanticTokens/full', uri)
+ local text = files.getText(uri)
+ while not text do
+ await.sleep(0.1)
+ text = files.getText(uri)
+ end
+ local results = core(uri, 0, #text)
+ if not results or #results == 0 then
+ return nil
+ end
+ return {
+ data = results
+ }
+end)
+
+proto.on('textDocument/semanticTokens/range', function (params)
+ local core = require 'core.semantic-tokens'
+ local uri = params.textDocument.uri
+ log.debug('semanticTokens/range', uri)
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+ while not lines or not text do
+ await.sleep(0.1)
+ lines = files.getLines(uri)
+ text = files.getText(uri)
+ end
+ local start = define.offset(lines, text, params.range.start)
+ local finish = define.offset(lines, text, params.range['end'])
+ local results = core(uri, start, finish)
+ if not results or #results == 0 then
+ return nil
+ end
+ return {
+ data = results
+ }
+end)
diff --git a/script/provider/semantic-tokens.lua b/script/provider/semantic-tokens.lua
new file mode 100644
index 00000000..17985bcd
--- /dev/null
+++ b/script/provider/semantic-tokens.lua
@@ -0,0 +1,64 @@
+local proto = require 'proto'
+local define = require 'proto.define'
+local client = require 'provider.client'
+
+local isEnable = false
+
+local function toArray(map)
+ local array = {}
+ for k in pairs(map) do
+ array[#array+1] = k
+ end
+ table.sort(array, function (a, b)
+ return map[a] < map[b]
+ end)
+ return array
+end
+
+local function enable()
+ if isEnable then
+ return
+ end
+ if not client.info.capabilities.textDocument.semanticTokens then
+ return
+ end
+ isEnable = true
+ log.debug('Enable semantic tokens.')
+ proto.awaitRequest('client/registerCapability', {
+ registrations = {
+ {
+ id = 'semantic-tokens',
+ method = 'textDocument/semanticTokens',
+ registerOptions = {
+ legend = {
+ tokenTypes = toArray(define.TokenTypes),
+ tokenModifiers = toArray(define.TokenModifiers),
+ },
+ range = true,
+ full = true,
+ },
+ },
+ }
+ })
+end
+
+local function disable()
+ if not isEnable then
+ return
+ end
+ isEnable = false
+ log.debug('Disable semantic tokens.')
+ proto.awaitRequest('client/unregisterCapability', {
+ unregisterations = {
+ {
+ id = 'semantic-tokens',
+ method = 'textDocument/semanticTokens',
+ },
+ }
+ })
+end
+
+return {
+ enable = enable,
+ disable = disable,
+}
diff --git a/script/pub/init.lua b/script/pub/init.lua
new file mode 100644
index 00000000..61b43da7
--- /dev/null
+++ b/script/pub/init.lua
@@ -0,0 +1,4 @@
+local pub = require 'pub.pub'
+require 'pub.report'
+
+return pub
diff --git a/script/pub/pub.lua b/script/pub/pub.lua
new file mode 100644
index 00000000..ad1cd749
--- /dev/null
+++ b/script/pub/pub.lua
@@ -0,0 +1,242 @@
+local thread = require 'bee.thread'
+local utility = require 'utility'
+local await = require 'await'
+local timer = require 'timer'
+
+local errLog = thread.channel 'errlog'
+local type = type
+local counter = utility.counter()
+
+local braveTemplate = [[
+package.path = %q
+package.cpath = %q
+DEVELOP = %s
+DBGPORT = %d
+DBGWAIT = %s
+
+collectgarbage 'generational'
+
+log = require 'brave.log'
+
+xpcall(dofile, log.error, %q)
+local brave = require 'brave'
+brave.register(%d)
+]]
+
+---@class pub
+local m = {}
+m.type = 'pub'
+m.braves = {}
+m.ability = {}
+m.taskQueue = {}
+
+--- 注册酒馆的功能
+function m.on(name, callback)
+ m.ability[name] = callback
+end
+
+--- 招募勇者,勇者会从公告板上领取任务,完成任务后到看板娘处交付任务
+---@param num integer
+function m.recruitBraves(num)
+ for _ = 1, num do
+ local id = #m.braves + 1
+ log.info('Create brave:', id)
+ thread.newchannel('taskpad' .. id)
+ thread.newchannel('waiter' .. id)
+ m.braves[id] = {
+ id = id,
+ taskpad = thread.channel('taskpad' .. id),
+ waiter = thread.channel('waiter' .. id),
+ thread = thread.thread(braveTemplate:format(
+ package.path,
+ package.cpath,
+ DEVELOP,
+ DBGPORT or 11412,
+ DBGWAIT or 'nil',
+ (ROOT / 'debugger.lua'):string(),
+ id
+ )),
+ taskMap = {},
+ currentTask = nil,
+ memory = 0,
+ }
+ end
+end
+
+--- 勇者是否有空
+function m.isIdle(brave)
+ return next(brave.taskMap) == nil
+end
+
+--- 给勇者推送任务
+function m.pushTask(brave, info)
+ if info.removed then
+ return false
+ end
+ brave.taskpad:push(info.name, info.id, info.params)
+ brave.taskMap[info.id] = info
+ --log.info(('Push task %q(%d) to # %d, queue length %d'):format(info.name, info.id, brave.id, #m.taskQueue))
+ return true
+end
+
+--- 从勇者处接收任务反馈
+function m.popTask(brave, id, result)
+ local info = brave.taskMap[id]
+ if not info then
+ log.warn(('Brave pushed unknown task result: # %d => [%d]'):format(brave.id, id))
+ return
+ end
+ brave.taskMap[id] = nil
+ --log.info(('Pop task %q(%d) from # %d'):format(info.name, info.id, brave.id))
+ m.checkWaitingTask(brave)
+ if not info.removed then
+ info.removed = true
+ if info.callback then
+ xpcall(info.callback, log.error, result)
+ end
+ end
+end
+
+--- 从勇者处接收报告
+function m.popReport(brave, name, params)
+ local abil = m.ability[name]
+ if not abil then
+ log.warn(('Brave pushed unknown report: # %d => %q'):format(brave.id, name))
+ return
+ end
+ xpcall(abil, log.error, params, brave)
+end
+
+--- 发布任务
+---@parma name string
+---@param params any
+function m.awaitTask(name, params)
+ local info = {
+ id = counter(),
+ name = name,
+ params = params,
+ }
+ for _, brave in ipairs(m.braves) do
+ if m.isIdle(brave) then
+ if m.pushTask(brave, info) then
+ return await.wait(function (waker)
+ info.callback = waker
+ end)
+ else
+ return nil
+ end
+ end
+ end
+ -- 如果所有勇者都在战斗,那么把任务缓存到队列里
+ -- 当有勇者提交任务反馈后,尝试把按顺序将堆积任务
+ -- 交给该勇者
+ m.taskQueue[#m.taskQueue+1] = info
+ --log.info(('Add task %q(%d) in queue, length %d.'):format(name, info.id, #m.taskQueue))
+ return await.wait(function (waker)
+ info.callback = waker
+ end)
+end
+
+--- 发布同步任务,如果任务进入了队列,会返回执行器
+--- 通过 jumpQueue 可以插队
+---@parma name string
+---@param params any
+---@param callback function
+function m.task(name, params, callback)
+ local info = {
+ id = counter(),
+ name = name,
+ params = params,
+ callback = callback,
+ }
+ for _, brave in ipairs(m.braves) do
+ if m.isIdle(brave) then
+ m.pushTask(brave, info)
+ return nil
+ end
+ end
+ -- 如果所有勇者都在战斗,那么把任务缓存到队列里
+ -- 当有勇者提交任务反馈后,尝试把按顺序将堆积任务
+ -- 交给该勇者
+ m.taskQueue[#m.taskQueue+1] = info
+ --log.info(('Add task %q(%d) in queue, length %d.'):format(name, info.id, #m.taskQueue))
+ return info
+end
+
+--- 插队
+function m.jumpQueue(info)
+ for i = 2, #m.taskQueue do
+ if m.taskQueue[i] == info then
+ m.taskQueue[i] = nil
+ table.move(m.taskQueue, 1, i - 1, 2)
+ m.taskQueue[1] = info
+ return
+ end
+ end
+end
+
+--- 移除任务
+function m.remove(info)
+ info.removed = true
+ for i = 1, #m.taskQueue do
+ if m.taskQueue[i] == info then
+ table.remove(m.taskQueue[i], i)
+ return
+ end
+ end
+end
+
+--- 检查堆积任务
+function m.checkWaitingTask(brave)
+ if #m.taskQueue == 0 then
+ return
+ end
+ -- 如果勇者还有其他活要忙,那么让他继续忙去吧
+ if next(brave.taskMap) then
+ return
+ end
+ while #m.taskQueue > 0 do
+ local info = table.remove(m.taskQueue, 1)
+ if m.pushTask(brave, info) then
+ break
+ end
+ end
+end
+
+--- 接收反馈
+---|返回接收到的反馈数量
+---@return integer
+function m.recieve()
+ for _, brave in ipairs(m.braves) do
+ while true do
+ local suc, id, result = brave.waiter:pop()
+ if not suc then
+ goto CONTINUE
+ end
+ if type(id) == 'string' then
+ m.popReport(brave, id, result)
+ else
+ m.popTask(brave, id, result)
+ end
+ end
+ ::CONTINUE::
+ end
+end
+
+--- 检查伤亡情况
+function m.checkDead()
+ while true do
+ local suc, err = errLog:pop()
+ if not suc then
+ break
+ end
+ log.error('Brave is dead!: ' .. err)
+ end
+end
+
+function m.step()
+ m.checkDead()
+ m.recieve()
+end
+
+return m
diff --git a/script/pub/report.lua b/script/pub/report.lua
new file mode 100644
index 00000000..549277e1
--- /dev/null
+++ b/script/pub/report.lua
@@ -0,0 +1,26 @@
+local pub = require 'pub.pub'
+local await = require 'await'
+
+pub.on('log', function (params, brave)
+ log.raw(brave.id, params.level, params.msg, params.src, params.line, params.clock)
+end)
+
+pub.on('mem', function (count, brave)
+ brave.memory = count
+end)
+
+pub.on('proto', function (params)
+ local proto = require 'proto'
+ await.call(function ()
+ if params.method then
+ proto.doMethod(params)
+ else
+ proto.doResponse(params)
+ end
+ end)
+end)
+
+pub.on('protoerror', function (err)
+ log.warn('Load proto error:', err)
+ os.exit(true)
+end)
diff --git a/script/service/init.lua b/script/service/init.lua
new file mode 100644
index 00000000..eb0bd057
--- /dev/null
+++ b/script/service/init.lua
@@ -0,0 +1,3 @@
+local service = require 'service.service'
+
+return service
diff --git a/script/service/service.lua b/script/service/service.lua
new file mode 100644
index 00000000..11cc7b19
--- /dev/null
+++ b/script/service/service.lua
@@ -0,0 +1,158 @@
+local pub = require 'pub'
+local thread = require 'bee.thread'
+local await = require 'await'
+local timer = require 'timer'
+local proto = require 'proto'
+local vm = require 'vm'
+
+local m = {}
+m.type = 'service'
+
+local function countMemory()
+ local mems = {}
+ local total = 0
+ mems[0] = collectgarbage 'count'
+ total = total + collectgarbage 'count'
+ for id, brave in ipairs(pub.braves) do
+ mems[id] = brave.memory
+ total = total + brave.memory
+ end
+ return total, mems
+end
+
+function m.reportMemoryCollect()
+ local totalMemBefore = countMemory()
+ local clock = os.clock()
+ collectgarbage()
+ local passed = os.clock() - clock
+ local totalMemAfter, mems = countMemory()
+
+ local lines = {}
+ lines[#lines+1] = ' --------------- Memory ---------------'
+ lines[#lines+1] = (' Total: %.3f(%.3f) MB'):format(totalMemAfter / 1000.0, totalMemBefore / 1000.0)
+ for i = 0, #mems do
+ lines[#lines+1] = (' # %02d : %.3f MB'):format(i, mems[i] / 1000.0)
+ end
+ lines[#lines+1] = (' Collect garbage takes [%.3f] sec'):format(passed)
+ return table.concat(lines, '\n')
+end
+
+function m.reportMemory()
+ local totalMem, mems = countMemory()
+
+ local lines = {}
+ lines[#lines+1] = ' --------------- Memory ---------------'
+ lines[#lines+1] = (' Total: %.3f MB'):format(totalMem / 1000.0)
+ for i = 0, #mems do
+ lines[#lines+1] = (' # %02d : %.3f MB'):format(i, mems[i] / 1000.0)
+ end
+ return table.concat(lines, '\n')
+end
+
+function m.reportTask()
+ local total = 0
+ local running = 0
+ local suspended = 0
+ local normal = 0
+ local dead = 0
+
+ for co in pairs(await.coMap) do
+ total = total + 1
+ local status = coroutine.status(co)
+ if status == 'running' then
+ running = running + 1
+ elseif status == 'suspended' then
+ suspended = suspended + 1
+ elseif status == 'normal' then
+ normal = normal + 1
+ elseif status == 'dead' then
+ dead = dead + 1
+ end
+ end
+
+ local lines = {}
+ lines[#lines+1] = ' --------------- Coroutine ---------------'
+ lines[#lines+1] = (' Total: %d'):format(total)
+ lines[#lines+1] = (' Running: %d'):format(running)
+ lines[#lines+1] = (' Suspended: %d'):format(suspended)
+ lines[#lines+1] = (' Normal: %d'):format(normal)
+ lines[#lines+1] = (' Dead: %d'):format(dead)
+ return table.concat(lines, '\n')
+end
+
+function m.reportCache()
+ local total = 0
+ local dead = 0
+
+ for cache in pairs(vm.cacheTracker) do
+ total = total + 1
+ if cache.dead then
+ dead = dead + 1
+ end
+ end
+
+ local lines = {}
+ lines[#lines+1] = ' --------------- Cache ---------------'
+ lines[#lines+1] = (' Total: %d'):format(total)
+ lines[#lines+1] = (' Dead: %d'):format(dead)
+ return table.concat(lines, '\n')
+end
+
+function m.reportProto()
+ local holdon = 0
+ local waiting = 0
+
+ for _ in pairs(proto.holdon) do
+ holdon = holdon + 1
+ end
+ for _ in pairs(proto.waiting) do
+ waiting = waiting + 1
+ end
+
+ local lines = {}
+ lines[#lines+1] = ' --------------- Proto ---------------'
+ lines[#lines+1] = (' Holdon: %d'):format(holdon)
+ lines[#lines+1] = (' Waiting: %d'):format(waiting)
+ return table.concat(lines, '\n')
+end
+
+function m.report()
+ local t = timer.loop(60.0, function ()
+ local lines = {}
+ lines[#lines+1] = ''
+ lines[#lines+1] = '========= Medical Examination Report ========='
+ lines[#lines+1] = m.reportMemory()
+ lines[#lines+1] = m.reportTask()
+ lines[#lines+1] = m.reportCache()
+ lines[#lines+1] = m.reportProto()
+ lines[#lines+1] = '=============================================='
+
+ log.debug(table.concat(lines, '\n'))
+ end)
+ t:onTimer()
+end
+
+function m.startTimer()
+ while true do
+ ::CONTINUE::
+ pub.step()
+ if await.step() then
+ goto CONTINUE
+ end
+ thread.sleep(0.001)
+ timer.update()
+ end
+end
+
+function m.start()
+ await.setErrorHandle(log.error)
+ pub.recruitBraves(4)
+ proto.listen()
+ m.report()
+
+ require 'provider'
+
+ m.startTimer()
+end
+
+return m
diff --git a/script/timer.lua b/script/timer.lua
new file mode 100644
index 00000000..1d4343f1
--- /dev/null
+++ b/script/timer.lua
@@ -0,0 +1,218 @@
+local setmetatable = setmetatable
+local mathMax = math.max
+local mathFloor = math.floor
+local osClock = os.clock
+
+_ENV = nil
+
+local curFrame = 0
+local maxFrame = 0
+local curIndex = 0
+local freeQueue = {}
+local timer = {}
+
+local function allocQueue()
+ local n = #freeQueue
+ if n > 0 then
+ local r = freeQueue[n]
+ freeQueue[n] = nil
+ return r
+ else
+ return {}
+ end
+end
+
+local function mTimeout(self, timeout)
+ if self._pauseRemaining or self._running then
+ return
+ end
+ local ti = curFrame + timeout
+ local q = timer[ti]
+ if q == nil then
+ q = allocQueue()
+ timer[ti] = q
+ end
+ self._timeoutFrame = ti
+ self._running = true
+ q[#q + 1] = self
+end
+
+local function mWakeup(self)
+ if self._removed then
+ return
+ end
+ self._running = false
+ if self._onTimer then
+ self:_onTimer()
+ end
+ if self._removed then
+ return
+ end
+ if self._timerCount then
+ if self._timerCount > 1 then
+ self._timerCount = self._timerCount - 1
+ mTimeout(self, self._timeout)
+ else
+ self._removed = true
+ end
+ else
+ mTimeout(self, self._timeout)
+ end
+end
+
+local function getRemaining(self)
+ if self._removed then
+ return 0
+ end
+ if self._pauseRemaining then
+ return self._pauseRemaining
+ end
+ if self._timeoutFrame == curFrame then
+ return self._timeout or 0
+ end
+ return self._timeoutFrame - curFrame
+end
+
+local function onTick()
+ local q = timer[curFrame]
+ if q == nil then
+ curIndex = 0
+ return
+ end
+ for i = curIndex + 1, #q do
+ local callback = q[i]
+ curIndex = i
+ q[i] = nil
+ if callback then
+ mWakeup(callback)
+ end
+ end
+ curIndex = 0
+ timer[curFrame] = nil
+ freeQueue[#freeQueue + 1] = q
+end
+
+local m = {}
+local mt = {}
+mt.__index = mt
+mt.type = 'timer'
+
+function mt:__tostring()
+ return '[table:timer]'
+end
+
+function mt:__call()
+ if self._onTimer then
+ self:_onTimer()
+ end
+end
+
+function mt:remove()
+ self._removed = true
+end
+
+function mt:pause()
+ if self._removed or self._pauseRemaining then
+ return
+ end
+ self._pauseRemaining = getRemaining(self)
+ self._running = false
+ local ti = self._timeoutFrame
+ local q = timer[ti]
+ if q then
+ for i = #q, 1, -1 do
+ if q[i] == self then
+ q[i] = false
+ return
+ end
+ end
+ end
+end
+
+function mt:resume()
+ if self._removed or not self._pauseRemaining then
+ return
+ end
+ local timeout = self._pauseRemaining
+ self._pauseRemaining = nil
+ mTimeout(self, timeout)
+end
+
+function mt:restart()
+ if self._removed or self._pauseRemaining or not self._running then
+ return
+ end
+ local ti = self._timeoutFrame
+ local q = timer[ti]
+ if q then
+ for i = #q, 1, -1 do
+ if q[i] == self then
+ q[i] = false
+ break
+ end
+ end
+ end
+ self._running = false
+ mTimeout(self, self._timeout)
+end
+
+function mt:remaining()
+ return getRemaining(self) / 1000.0
+end
+
+function mt:onTimer()
+ self:_onTimer()
+end
+
+function m.wait(timeout, onTimer)
+ local t = setmetatable({
+ ['_timeout'] = mathMax(mathFloor(timeout * 1000.0), 1),
+ ['_onTimer'] = onTimer,
+ ['_timerCount'] = 1,
+ }, mt)
+ mTimeout(t, t._timeout)
+ return t
+end
+
+function m.loop(timeout, onTimer)
+ local t = setmetatable({
+ ['_timeout'] = mathFloor(timeout * 1000.0),
+ ['_onTimer'] = onTimer,
+ }, mt)
+ mTimeout(t, t._timeout)
+ return t
+end
+
+function m.timer(timeout, count, onTimer)
+ if count == 0 then
+ return m.loop(timeout, onTimer)
+ end
+ local t = setmetatable({
+ ['_timeout'] = mathFloor(timeout * 1000.0),
+ ['_onTimer'] = onTimer,
+ ['_timerCount'] = count,
+ }, mt)
+ mTimeout(t, t._timeout)
+ return t
+end
+
+function m.clock()
+ return curFrame / 1000.0
+end
+
+local lastClock = osClock()
+function m.update()
+ local currentClock = osClock()
+ local delta = currentClock - lastClock
+ lastClock = currentClock
+ if curIndex ~= 0 then
+ curFrame = curFrame - 1
+ end
+ maxFrame = maxFrame + delta * 1000.0
+ while curFrame < maxFrame do
+ curFrame = curFrame + 1
+ onTick()
+ end
+end
+
+return m
diff --git a/script/utility.lua b/script/utility.lua
new file mode 100644
index 00000000..a1ea1804
--- /dev/null
+++ b/script/utility.lua
@@ -0,0 +1,559 @@
+local tableSort = table.sort
+local stringRep = string.rep
+local tableConcat = table.concat
+local tostring = tostring
+local type = type
+local pairs = pairs
+local ipairs = ipairs
+local next = next
+local rawset = rawset
+local move = table.move
+local setmetatable = setmetatable
+local mathType = math.type
+local mathCeil = math.ceil
+local getmetatable = getmetatable
+local mathAbs = math.abs
+local mathRandom = math.random
+local ioOpen = io.open
+local utf8Len = utf8.len
+local mathHuge = math.huge
+local inf = 1 / 0
+local nan = 0 / 0
+
+_ENV = nil
+
+local function formatNumber(n)
+ if n == inf
+ or n == -inf
+ or n == nan
+ or n ~= n then -- IEEE 标准中,NAN 不等于自己。但是某些实现中没有遵守这个规则
+ return ('%q'):format(n)
+ end
+ local str = ('%.10f'):format(n)
+ str = str:gsub('%.?0*$', '')
+ return str
+end
+
+local function isInteger(n)
+ if mathType then
+ return mathType(n) == 'integer'
+ else
+ return type(n) == 'number' and n % 1 == 0
+ end
+end
+
+local TAB = setmetatable({}, { __index = function (self, n)
+ self[n] = stringRep(' ', n)
+ return self[n]
+end})
+
+local RESERVED = {
+ ['and'] = true,
+ ['break'] = true,
+ ['do'] = true,
+ ['else'] = true,
+ ['elseif'] = true,
+ ['end'] = true,
+ ['false'] = true,
+ ['for'] = true,
+ ['function'] = true,
+ ['goto'] = true,
+ ['if'] = true,
+ ['in'] = true,
+ ['local'] = true,
+ ['nil'] = true,
+ ['not'] = true,
+ ['or'] = true,
+ ['repeat'] = true,
+ ['return'] = true,
+ ['then'] = true,
+ ['true'] = true,
+ ['until'] = true,
+ ['while'] = true,
+}
+
+local m = {}
+
+--- 打印表的结构
+---@param tbl table
+---@param option table {optional = 'self'}
+---@return string
+function m.dump(tbl, option)
+ if not option then
+ option = {}
+ end
+ if type(tbl) ~= 'table' then
+ return ('%s'):format(tbl)
+ end
+ local lines = {}
+ local mark = {}
+ lines[#lines+1] = '{'
+ local function unpack(tbl, deep)
+ mark[tbl] = (mark[tbl] or 0) + 1
+ local keys = {}
+ local keymap = {}
+ local integerFormat = '[%d]'
+ local alignment = 0
+ if #tbl >= 10 then
+ local width = #tostring(#tbl)
+ integerFormat = ('[%%0%dd]'):format(mathCeil(width))
+ end
+ for key in pairs(tbl) do
+ if type(key) == 'string' then
+ if not key:match('^[%a_][%w_]*$')
+ or RESERVED[key]
+ or option['longStringKey']
+ then
+ keymap[key] = ('[%q]'):format(key)
+ else
+ keymap[key] = ('%s'):format(key)
+ end
+ elseif isInteger(key) then
+ keymap[key] = integerFormat:format(key)
+ else
+ keymap[key] = ('["<%s>"]'):format(tostring(key))
+ end
+ keys[#keys+1] = key
+ if option['alignment'] then
+ if #keymap[key] > alignment then
+ alignment = #keymap[key]
+ end
+ end
+ end
+ local mt = getmetatable(tbl)
+ if not mt or not mt.__pairs then
+ if option['sorter'] then
+ option['sorter'](keys, keymap)
+ else
+ tableSort(keys, function (a, b)
+ return keymap[a] < keymap[b]
+ end)
+ end
+ end
+ for _, key in ipairs(keys) do
+ local keyWord = keymap[key]
+ if option['noArrayKey']
+ and isInteger(key)
+ and key <= #tbl
+ then
+ keyWord = ''
+ else
+ if #keyWord < alignment then
+ keyWord = keyWord .. (' '):rep(alignment - #keyWord) .. ' = '
+ else
+ keyWord = keyWord .. ' = '
+ end
+ end
+ local value = tbl[key]
+ local tp = type(value)
+ if option['format'] and option['format'][key] then
+ lines[#lines+1] = ('%s%s%s,'):format(TAB[deep+1], keyWord, option['format'][key](value, unpack, deep+1))
+ elseif tp == 'table' then
+ if mark[value] and mark[value] > 0 then
+ lines[#lines+1] = ('%s%s%s,'):format(TAB[deep+1], keyWord, option['loop'] or '"<Loop>"')
+ elseif deep >= (option['deep'] or mathHuge) then
+ lines[#lines+1] = ('%s%s%s,'):format(TAB[deep+1], keyWord, '"<Deep>"')
+ else
+ lines[#lines+1] = ('%s%s{'):format(TAB[deep+1], keyWord)
+ unpack(value, deep+1)
+ lines[#lines+1] = ('%s},'):format(TAB[deep+1])
+ end
+ elseif tp == 'string' then
+ lines[#lines+1] = ('%s%s%q,'):format(TAB[deep+1], keyWord, value)
+ elseif tp == 'number' then
+ lines[#lines+1] = ('%s%s%s,'):format(TAB[deep+1], keyWord, (option['number'] or formatNumber)(value))
+ elseif tp == 'nil' then
+ else
+ lines[#lines+1] = ('%s%s%s,'):format(TAB[deep+1], keyWord, tostring(value))
+ end
+ end
+ mark[tbl] = mark[tbl] - 1
+ end
+ unpack(tbl, 0)
+ lines[#lines+1] = '}'
+ return tableConcat(lines, '\r\n')
+end
+
+--- 递归判断A与B是否相等
+---@param a any
+---@param b any
+---@return boolean
+function m.equal(a, b)
+ local tp1 = type(a)
+ local tp2 = type(b)
+ if tp1 ~= tp2 then
+ return false
+ end
+ if tp1 == 'table' then
+ local mark = {}
+ for k, v in pairs(a) do
+ mark[k] = true
+ local res = m.equal(v, b[k])
+ if not res then
+ return false
+ end
+ end
+ for k in pairs(b) do
+ if not mark[k] then
+ return false
+ end
+ end
+ return true
+ elseif tp1 == 'number' then
+ return mathAbs(a - b) <= 1e-10
+ else
+ return a == b
+ end
+end
+
+local function sortTable(tbl)
+ if not tbl then
+ tbl = {}
+ end
+ local mt = {}
+ local keys = {}
+ local mark = {}
+ local n = 0
+ for key in next, tbl do
+ n=n+1;keys[n] = key
+ mark[key] = true
+ end
+ tableSort(keys)
+ function mt:__newindex(key, value)
+ rawset(self, key, value)
+ n=n+1;keys[n] = key
+ mark[key] = true
+ if type(value) == 'table' then
+ sortTable(value)
+ end
+ end
+ function mt:__pairs()
+ local list = {}
+ local m = 0
+ for key in next, self do
+ if not mark[key] then
+ m=m+1;list[m] = key
+ end
+ end
+ if m > 0 then
+ move(keys, 1, n, m+1)
+ tableSort(list)
+ for i = 1, m do
+ local key = list[i]
+ keys[i] = key
+ mark[key] = true
+ end
+ n = n + m
+ end
+ local i = 0
+ return function ()
+ i = i + 1
+ local key = keys[i]
+ return key, self[key]
+ end
+ end
+
+ return setmetatable(tbl, mt)
+end
+
+--- 创建一个有序表
+---@param tbl table {optional = 'self'}
+---@return table
+function m.container(tbl)
+ return sortTable(tbl)
+end
+
+--- 读取文件
+---@param path string
+function m.loadFile(path)
+ local f, e = ioOpen(path, 'rb')
+ if not f then
+ return nil, e
+ end
+ if f:read(3) ~= '\xEF\xBB\xBF' then
+ f:seek("set")
+ end
+ local buf = f:read 'a'
+ f:close()
+ return buf
+end
+
+--- 写入文件
+---@param path string
+---@param content string
+function m.saveFile(path, content)
+ local f, e = ioOpen(path, "wb")
+
+ if f then
+ f:write(content)
+ f:close()
+ return true
+ else
+ return false, e
+ end
+end
+
+--- 计数器
+---@param init integer {optional = 'after'}
+---@param step integer {optional = 'after'}
+---@return fun():integer
+function m.counter(init, step)
+ if not step then
+ step = 1
+ end
+ local current = init and (init - 1) or 0
+ return function ()
+ current = current + step
+ return current
+ end
+end
+
+--- 排序后遍历
+---@param t table
+function m.sortPairs(t)
+ local keys = {}
+ for k in pairs(t) do
+ keys[#keys+1] = k
+ end
+ tableSort(keys)
+ local i = 0
+ return function ()
+ i = i + 1
+ local k = keys[i]
+ return k, t[k]
+ end
+end
+
+--- 深拷贝(不处理元表)
+---@param source table
+---@param target table {optional = 'self'}
+function m.deepCopy(source, target)
+ local mark = {}
+ local function copy(a, b)
+ if type(a) ~= 'table' then
+ return a
+ end
+ if mark[a] then
+ return mark[a]
+ end
+ if not b then
+ b = {}
+ end
+ mark[a] = b
+ for k, v in pairs(a) do
+ b[copy(k)] = copy(v)
+ end
+ return b
+ end
+ return copy(source, target)
+end
+
+--- 序列化
+function m.unpack(t)
+ local result = {}
+ local tid = 0
+ local cache = {}
+ local function unpack(o)
+ local id = cache[o]
+ if not id then
+ tid = tid + 1
+ id = tid
+ cache[o] = tid
+ if type(o) == 'table' then
+ local new = {}
+ result[tid] = new
+ for k, v in next, o do
+ new[unpack(k)] = unpack(v)
+ end
+ else
+ result[id] = o
+ end
+ end
+ return id
+ end
+ unpack(t)
+ return result
+end
+
+--- 反序列化
+function m.pack(t)
+ local cache = {}
+ local function pack(id)
+ local o = cache[id]
+ if o then
+ return o
+ end
+ o = t[id]
+ if type(o) == 'table' then
+ local new = {}
+ cache[id] = new
+ for k, v in next, o do
+ new[pack(k)] = pack(v)
+ end
+ return new
+ else
+ cache[id] = o
+ return o
+ end
+ end
+ return pack(1)
+end
+
+--- defer
+local deferMT = { __close = function (self) self[1]() end }
+function m.defer(callback)
+ return setmetatable({ callback }, deferMT)
+end
+
+local esc = {
+ ["'"] = [[\']],
+ ['"'] = [[\"]],
+ ['\r'] = [[\r]],
+ ['\n'] = '\\\n',
+}
+
+function m.viewString(str, quo)
+ if not quo then
+ if str:find('[\r\n]') then
+ quo = '[['
+ elseif not str:find("'", 1, true) and str:find('"', 1, true) then
+ quo = "'"
+ else
+ quo = '"'
+ end
+ end
+ if quo == "'" then
+ str = str:gsub('[\000-\008\011-\012\014-\031\127]', function (char)
+ return ('\\%03d'):format(char:byte())
+ end)
+ return quo .. str:gsub([=[['\r\n]]=], esc) .. quo
+ elseif quo == '"' then
+ str = str:gsub('[\000-\008\011-\012\014-\031\127]', function (char)
+ return ('\\%03d'):format(char:byte())
+ end)
+ return quo .. str:gsub([=[["\r\n]]=], esc) .. quo
+ else
+ local eqnum = #quo - 2
+ local fsymb = ']' .. ('='):rep(eqnum) .. ']'
+ if not str:find(fsymb, 1, true) then
+ str = str:gsub('[\000-\008\011-\012\014-\031\127]', '')
+ return quo .. str .. fsymb
+ end
+ for i = 0, 10 do
+ local fsymb = ']' .. ('='):rep(i) .. ']'
+ if not str:find(fsymb, 1, true) then
+ local ssymb = '[' .. ('='):rep(i) .. '['
+ str = str:gsub('[\000-\008\011-\012\014-\031\127]', '')
+ return ssymb .. str .. fsymb
+ end
+ end
+ return m.viewString(str, '"')
+ end
+end
+
+function m.viewLiteral(v)
+ local tp = type(v)
+ if tp == 'nil' then
+ return 'nil'
+ elseif tp == 'string' then
+ return m.viewString(v)
+ elseif tp == 'boolean' then
+ return tostring(v)
+ elseif tp == 'number' then
+ if isInteger(v) then
+ return tostring(v)
+ else
+ return formatNumber(v)
+ end
+ end
+ return nil
+end
+
+function m.utf8Len(str, start, finish)
+ local len, pos = utf8Len(str, start, finish, true)
+ if len then
+ return len
+ end
+ return 1 + m.utf8Len(str, start, pos-1) + m.utf8Len(str, pos+1, finish)
+end
+
+function m.revertTable(t)
+ local len = #t
+ if len <= 1 then
+ return t
+ end
+ for x = 1, len // 2 do
+ local y = len - x + 1
+ t[x], t[y] = t[y], t[x]
+ end
+ return t
+end
+
+function m.randomSortTable(t, max)
+ local len = #t
+ if len <= 1 then
+ return t
+ end
+ if not max or max > len then
+ max = len
+ end
+ for x = 1, max do
+ local y = mathRandom(len)
+ t[x], t[y] = t[y], t[x]
+ end
+ return t
+end
+
+function m.tableMultiRemove(t, index)
+ local mark = {}
+ for i = 1, #index do
+ local v = index[i]
+ mark[v] = true
+ end
+ local offset = 0
+ local me = 1
+ local len = #t
+ while true do
+ local it = me + offset
+ if it > len then
+ for i = me, len do
+ t[i] = nil
+ end
+ break
+ end
+ if mark[it] then
+ offset = offset + 1
+ else
+ if me ~= it then
+ t[me] = t[it]
+ end
+ me = me + 1
+ end
+ end
+end
+
+function m.eachLine(text)
+ local offset = 1
+ local lineCount = 0
+ return function ()
+ if offset > #text then
+ return nil
+ end
+ lineCount = lineCount + 1
+ local nl = text:find('[\r\n]', offset)
+ if not nl then
+ local lastLine = text:sub(offset)
+ offset = #text + 1
+ return lastLine
+ end
+ local line = text:sub(offset, nl - 1)
+ if text:sub(nl, nl + 1) == '\r\n' then
+ offset = nl + 2
+ else
+ offset = nl + 1
+ end
+ return line
+ end
+end
+
+return m
diff --git a/script/vm/eachDef.lua b/script/vm/eachDef.lua
new file mode 100644
index 00000000..5ff58889
--- /dev/null
+++ b/script/vm/eachDef.lua
@@ -0,0 +1,40 @@
+local vm = require 'vm.vm'
+local guide = require 'parser.guide'
+local files = require 'files'
+local util = require 'utility'
+local await = require 'await'
+
+local function eachDef(source, deep)
+ local results = {}
+ local lock = vm.lock('eachDef', source)
+ if not lock then
+ return results
+ end
+
+ await.delay()
+
+ local clock = os.clock()
+ local myResults, count = guide.requestDefinition(source, vm.interface, deep)
+ if DEVELOP and os.clock() - clock > 0.1 then
+ log.warn('requestDefinition', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 }))
+ end
+ vm.mergeResults(results, myResults)
+
+ lock()
+
+ return results
+end
+
+function vm.getDefs(source, deep)
+ if guide.isGlobal(source) then
+ local key = guide.getKeyName(source)
+ return vm.getGlobalSets(key)
+ else
+ local cache = vm.getCache('eachDef')[source]
+ or eachDef(source, deep)
+ if deep then
+ vm.getCache('eachDef')[source] = cache
+ end
+ return cache
+ end
+end
diff --git a/script/vm/eachField.lua b/script/vm/eachField.lua
new file mode 100644
index 00000000..ce0e3928
--- /dev/null
+++ b/script/vm/eachField.lua
@@ -0,0 +1,45 @@
+local vm = require 'vm.vm'
+local guide = require 'parser.guide'
+local await = require 'await'
+
+local function eachField(source, deep)
+ local unlock = vm.lock('eachField', source)
+ if not unlock then
+ return {}
+ end
+
+ while source.type == 'paren' do
+ source = source.exp
+ if not source then
+ return {}
+ end
+ end
+
+ await.delay()
+ local results = guide.requestFields(source, vm.interface, deep)
+
+ unlock()
+ return results
+end
+
+function vm.getFields(source, deep)
+ if source.special == '_G' then
+ return vm.getGlobals '*'
+ end
+ if guide.isGlobal(source) then
+ local name = guide.getKeyName(source)
+ local cache = vm.getCache('eachFieldOfGlobal')[name]
+ or vm.getCache('eachField')[source]
+ or eachField(source, 'deep')
+ vm.getCache('eachFieldOfGlobal')[name] = cache
+ vm.getCache('eachField')[source] = cache
+ return cache
+ else
+ local cache = vm.getCache('eachField')[source]
+ or eachField(source, deep)
+ if deep then
+ vm.getCache('eachField')[source] = cache
+ end
+ return cache
+ end
+end
diff --git a/script/vm/eachRef.lua b/script/vm/eachRef.lua
new file mode 100644
index 00000000..4e735abf
--- /dev/null
+++ b/script/vm/eachRef.lua
@@ -0,0 +1,39 @@
+local vm = require 'vm.vm'
+local guide = require 'parser.guide'
+local util = require 'utility'
+local await = require 'await'
+
+local function getRefs(source, deep)
+ local results = {}
+ local lock = vm.lock('eachRef', source)
+ if not lock then
+ return results
+ end
+
+ await.delay()
+
+ local clock = os.clock()
+ local myResults, count = guide.requestReference(source, vm.interface, deep)
+ if DEVELOP and os.clock() - clock > 0.1 then
+ log.warn('requestReference', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 }))
+ end
+ vm.mergeResults(results, myResults)
+
+ lock()
+
+ return results
+end
+
+function vm.getRefs(source, deep)
+ if guide.isGlobal(source) then
+ local key = guide.getKeyName(source)
+ return vm.getGlobals(key)
+ else
+ local cache = vm.getCache('eachRef')[source]
+ or getRefs(source, deep)
+ if deep then
+ vm.getCache('eachRef')[source] = cache
+ end
+ return cache
+ end
+end
diff --git a/script/vm/getClass.lua b/script/vm/getClass.lua
new file mode 100644
index 00000000..a8bd7e40
--- /dev/null
+++ b/script/vm/getClass.lua
@@ -0,0 +1,62 @@
+local vm = require 'vm.vm'
+local guide = require 'parser.guide'
+
+local function lookUpDocClass(source)
+ local infers = vm.getInfers(source, 'deep')
+ for _, infer in ipairs(infers) do
+ if infer.source.type == 'doc.class'
+ or infer.source.type == 'doc.type' then
+ return infer.type
+ end
+ end
+end
+
+local function getClass(source, classes, depth, deep)
+ local docClass = lookUpDocClass(source)
+ if docClass then
+ classes[#classes+1] = docClass
+ return
+ end
+ if depth > 3 then
+ return
+ end
+ local value = guide.getObjectValue(source) or source
+ if not deep then
+ if value and value.type == 'string' then
+ classes[#classes+1] = value[1]
+ end
+ else
+ for _, src in ipairs(vm.getFields(value)) do
+ local key = vm.getKeyName(src)
+ if not key then
+ goto CONTINUE
+ end
+ local lkey = key:lower()
+ if lkey == 's|type'
+ or lkey == 's|__name'
+ or lkey == 's|name'
+ or lkey == 's|class' then
+ local value = guide.getObjectValue(src)
+ if value and value.type == 'string' then
+ classes[#classes+1] = value[1]
+ end
+ end
+ ::CONTINUE::
+ end
+ end
+ if #classes ~= 0 then
+ return
+ end
+ vm.eachMeta(source, function (mt)
+ getClass(mt, classes, depth + 1, deep)
+ end)
+end
+
+function vm.getClass(source, deep)
+ local classes = {}
+ getClass(source, classes, 1, deep)
+ if #classes == 0 then
+ return nil
+ end
+ return guide.mergeTypes(classes)
+end
diff --git a/script/vm/getDocs.lua b/script/vm/getDocs.lua
new file mode 100644
index 00000000..c06efc11
--- /dev/null
+++ b/script/vm/getDocs.lua
@@ -0,0 +1,175 @@
+local files = require 'files'
+local util = require 'utility'
+local guide = require 'parser.guide'
+local vm = require 'vm.vm'
+local config = require 'config'
+
+local function getTypesOfFile(uri)
+ local types = {}
+ local ast = files.getAst(uri)
+ if not ast or not ast.ast.docs then
+ return types
+ end
+ guide.eachSource(ast.ast.docs, function (src)
+ if src.type == 'doc.type.name'
+ or src.type == 'doc.class.name'
+ or src.type == 'doc.extends.name'
+ or src.type == 'doc.alias.name' then
+ local name = src[1]
+ if name then
+ if not types[name] then
+ types[name] = {}
+ end
+ types[name][#types[name]+1] = src
+ end
+ end
+ end)
+ return types
+end
+
+local function getDocTypes(name)
+ local results = {}
+ for uri in files.eachFile() do
+ local cache = files.getCache(uri)
+ cache.classes = cache.classes or getTypesOfFile(uri)
+ if name == '*' then
+ for _, sources in util.sortPairs(cache.classes) do
+ for _, source in ipairs(sources) do
+ results[#results+1] = source
+ end
+ end
+ else
+ if cache.classes[name] then
+ for _, source in ipairs(cache.classes[name]) do
+ results[#results+1] = source
+ end
+ end
+ end
+ end
+ return results
+end
+
+function vm.getDocEnums(doc, mark, results)
+ mark = mark or {}
+ if mark[doc] then
+ return nil
+ end
+ mark[doc] = true
+ results = results or {}
+ for _, enum in ipairs(doc.enums) do
+ results[#results+1] = enum
+ end
+ for _, resume in ipairs(doc.resumes) do
+ results[#results+1] = resume
+ end
+ for _, unit in ipairs(doc.types) do
+ if unit.type == 'doc.type.name' then
+ for _, other in ipairs(vm.getDocTypes(unit[1])) do
+ if other.type == 'doc.alias.name' then
+ vm.getDocEnums(other.parent.extends, mark, results)
+ end
+ end
+ end
+ end
+ return results
+end
+
+function vm.getDocTypes(name)
+ local cache = vm.getCache('getDocTypes')[name]
+ if cache ~= nil then
+ return cache
+ end
+ cache = getDocTypes(name)
+ vm.getCache('getDocTypes')[name] = cache
+ return cache
+end
+
+function vm.isMetaFile(uri)
+ local status = files.getAst(uri)
+ if not status then
+ return false
+ end
+ local cache = files.getCache(uri)
+ if cache.isMeta ~= nil then
+ return cache.isMeta
+ end
+ cache.isMeta = false
+ if not status.ast.docs then
+ return false
+ end
+ for _, doc in ipairs(status.ast.docs) do
+ if doc.type == 'doc.meta' then
+ cache.isMeta = true
+ return true
+ end
+ end
+ return false
+end
+
+function vm.getValidVersions(doc)
+ if doc.type ~= 'doc.version' then
+ return
+ end
+ local valids = {
+ ['Lua 5.1'] = false,
+ ['Lua 5.2'] = false,
+ ['Lua 5.3'] = false,
+ ['Lua 5.4'] = false,
+ ['LuaJIT'] = false,
+ }
+ for _, version in ipairs(doc.versions) do
+ if version.ge and type(version.version) == 'number' then
+ for ver in pairs(valids) do
+ local verNumber = tonumber(ver:sub(-3))
+ if verNumber and verNumber >= version.version then
+ valids[ver] = true
+ end
+ end
+ elseif version.le and type(version.version) == 'number' then
+ for ver in pairs(valids) do
+ local verNumber = tonumber(ver:sub(-3))
+ if verNumber and verNumber <= version.version then
+ valids[ver] = true
+ end
+ end
+ elseif type(version.version) == 'number' then
+ valids[('Lua %.1f'):format(version.version)] = true
+ elseif 'JIT' == version.version then
+ valids['LuaJIT'] = true
+ end
+ end
+ if valids['Lua 5.1'] then
+ valids['LuaJIT'] = true
+ end
+ return valids
+end
+
+local function isDeprecated(value)
+ if not value.bindDocs then
+ return false
+ end
+ for _, doc in ipairs(value.bindDocs) do
+ if doc.type == 'doc.deprecated' then
+ return true
+ elseif doc.type == 'doc.version' then
+ local valids = vm.getValidVersions(doc)
+ if not valids[config.config.runtime.version] then
+ return true
+ end
+ end
+ end
+ return false
+end
+
+function vm.isDeprecated(value)
+ local defs = vm.getDefs(value, 'deep')
+ if #defs == 0 then
+ return false
+ end
+ for _, def in ipairs(defs) do
+ if not isDeprecated(def) then
+ return false
+ end
+ end
+ return true
+end
diff --git a/script/vm/getGlobals.lua b/script/vm/getGlobals.lua
new file mode 100644
index 00000000..83d6a5e6
--- /dev/null
+++ b/script/vm/getGlobals.lua
@@ -0,0 +1,192 @@
+local guide = require 'parser.guide'
+local vm = require 'vm.vm'
+local files = require 'files'
+local util = require 'utility'
+local config = require 'config'
+
+local function getGlobalsOfFile(uri)
+ local cache = files.getCache(uri)
+ if cache.globals then
+ return cache.globals
+ end
+ local globals = {}
+ cache.globals = globals
+ local ast = files.getAst(uri)
+ if not ast then
+ return globals
+ end
+ local results = guide.findGlobals(ast.ast)
+ local mark = {}
+ for _, res in ipairs(results) do
+ if mark[res] then
+ goto CONTINUE
+ end
+ mark[res] = true
+ local name = guide.getSimpleName(res)
+ if name then
+ if not globals[name] then
+ globals[name] = {}
+ end
+ globals[name][#globals[name]+1] = res
+ end
+ ::CONTINUE::
+ end
+ return globals
+end
+
+local function getGlobalSetsOfFile(uri)
+ local cache = files.getCache(uri)
+ if cache.globalSets then
+ return cache.globalSets
+ end
+ local globals = {}
+ cache.globalSets = globals
+ local ast = files.getAst(uri)
+ if not ast then
+ return globals
+ end
+ local results = guide.findGlobals(ast.ast)
+ local mark = {}
+ for _, res in ipairs(results) do
+ if mark[res] then
+ goto CONTINUE
+ end
+ mark[res] = true
+ if vm.isSet(res) then
+ local name = guide.getSimpleName(res)
+ if name then
+ if not globals[name] then
+ globals[name] = {}
+ end
+ globals[name][#globals[name]+1] = res
+ end
+ end
+ ::CONTINUE::
+ end
+ return globals
+end
+
+local function getGlobals(name)
+ local results = {}
+ for uri in files.eachFile() do
+ local globals = getGlobalsOfFile(uri)
+ if name == '*' then
+ for _, sources in util.sortPairs(globals) do
+ for _, source in ipairs(sources) do
+ results[#results+1] = source
+ end
+ end
+ else
+ if globals[name] then
+ for _, source in ipairs(globals[name]) do
+ results[#results+1] = source
+ end
+ end
+ end
+ end
+ return results
+end
+
+local function getGlobalSets(name)
+ local results = {}
+ for uri in files.eachFile() do
+ local globals = getGlobalSetsOfFile(uri)
+ if name == '*' then
+ for _, sources in util.sortPairs(globals) do
+ for _, source in ipairs(sources) do
+ results[#results+1] = source
+ end
+ end
+ else
+ if globals[name] then
+ for _, source in ipairs(globals[name]) do
+ results[#results+1] = source
+ end
+ end
+ end
+ end
+ return results
+end
+
+local function fastGetAnyGlobals()
+ local results = {}
+ local mark = {}
+ for uri in files.eachFile() do
+ --local globalSets = getGlobalsOfFile(uri)
+ --for destName, sources in util.sortPairs(globalSets) do
+ -- if not mark[destName] then
+ -- mark[destName] = true
+ -- results[#results+1] = sources[1]
+ -- end
+ --end
+ local globals = getGlobalsOfFile(uri)
+ for destName, sources in util.sortPairs(globals) do
+ if not mark[destName] then
+ mark[destName] = true
+ results[#results+1] = sources[1]
+ end
+ end
+ end
+ return results
+end
+
+local function fastGetAnyGlobalSets()
+ local results = {}
+ local mark = {}
+ for uri in files.eachFile() do
+ local globals = getGlobalSetsOfFile(uri)
+ for destName, sources in util.sortPairs(globals) do
+ if not mark[destName] then
+ mark[destName] = true
+ results[#results+1] = sources[1]
+ end
+ end
+ end
+ return results
+end
+
+function vm.getGlobals(key)
+ if key == '*' and config.config.intelliSense.fastGlobal then
+ local cache = vm.getCache('fastGetAnyGlobals')[key]
+ if cache ~= nil then
+ return cache
+ end
+ cache = fastGetAnyGlobals()
+ vm.getCache('fastGetAnyGlobals')[key] = cache
+ return cache
+ else
+ local cache = vm.getCache('getGlobals')[key]
+ if cache ~= nil then
+ return cache
+ end
+ cache = getGlobals(key)
+ vm.getCache('getGlobals')[key] = cache
+ return cache
+ end
+end
+
+function vm.getGlobalSets(key)
+ if key == '*' and config.config.intelliSense.fastGlobal then
+ local cache = vm.getCache('fastGetAnyGlobalSets')[key]
+ if cache ~= nil then
+ return cache
+ end
+ cache = fastGetAnyGlobalSets()
+ vm.getCache('fastGetAnyGlobalSets')[key] = cache
+ return cache
+ end
+ local cache = vm.getCache('getGlobalSets')[key]
+ if cache ~= nil then
+ return cache
+ end
+ cache = getGlobalSets(key)
+ vm.getCache('getGlobalSets')[key] = cache
+ return cache
+end
+
+files.watch(function (ev, uri)
+ if ev == 'update' then
+ getGlobalsOfFile(uri)
+ getGlobalSetsOfFile(uri)
+ end
+end)
diff --git a/script/vm/getInfer.lua b/script/vm/getInfer.lua
new file mode 100644
index 00000000..5272c389
--- /dev/null
+++ b/script/vm/getInfer.lua
@@ -0,0 +1,96 @@
+local vm = require 'vm.vm'
+local guide = require 'parser.guide'
+local util = require 'utility'
+local await = require 'await'
+
+NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end })
+
+--- 是否包含某种类型
+function vm.hasType(source, type, deep)
+ local defs = vm.getDefs(source, deep)
+ for i = 1, #defs do
+ local def = defs[i]
+ local value = guide.getObjectValue(def) or def
+ if value.type == type then
+ return true
+ end
+ end
+ return false
+end
+
+--- 是否包含某种类型
+function vm.hasInferType(source, type, deep)
+ local infers = vm.getInfers(source, deep)
+ for i = 1, #infers do
+ local infer = infers[i]
+ if infer.type == type then
+ return true
+ end
+ end
+ return false
+end
+
+function vm.getInferType(source, deep)
+ local infers = vm.getInfers(source, deep)
+ return guide.viewInferType(infers)
+end
+
+function vm.getInferLiteral(source, deep)
+ local infers = vm.getInfers(source, deep)
+ local literals = {}
+ local mark = {}
+ for _, infer in ipairs(infers) do
+ local value = infer.value
+ if value and not mark[value] then
+ mark[value] = true
+ literals[#literals+1] = util.viewLiteral(value)
+ end
+ end
+ if #literals == 0 then
+ return nil
+ end
+ table.sort(literals)
+ return table.concat(literals, '|')
+end
+
+local function getInfers(source, deep)
+ local results = {}
+ local lock = vm.lock('getInfers', source)
+ if not lock then
+ return results
+ end
+
+ await.delay()
+
+ local clock = os.clock()
+ local myResults, count = guide.requestInfer(source, vm.interface, deep)
+ if DEVELOP and os.clock() - clock > 0.1 then
+ log.warn('requestInfer', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 }))
+ end
+ vm.mergeResults(results, myResults)
+
+ lock()
+
+ return results
+end
+
+--- 获取对象的值
+--- 会尝试穿透函数调用
+function vm.getInfers(source, deep)
+ if guide.isGlobal(source) then
+ local name = guide.getKeyName(source)
+ local cache = vm.getCache('getInfersOfGlobal')[name]
+ or vm.getCache('getInfers')[source]
+ or getInfers(source, 'deep')
+ vm.getCache('getInfersOfGlobal')[name] = cache
+ vm.getCache('getInfers')[source] = cache
+ return cache
+ else
+ local cache = vm.getCache('getInfers')[source]
+ or getInfers(source, deep)
+ if deep then
+ vm.getCache('getInfers')[source] = cache
+ end
+ return cache
+ end
+end
diff --git a/script/vm/getLibrary.lua b/script/vm/getLibrary.lua
new file mode 100644
index 00000000..5803a73b
--- /dev/null
+++ b/script/vm/getLibrary.lua
@@ -0,0 +1,32 @@
+local vm = require 'vm.vm'
+
+function vm.getLibraryName(source, deep)
+ local defs = vm.getDefs(source, deep)
+ for _, def in ipairs(defs) do
+ if def.special then
+ return def.special
+ end
+ end
+ return nil
+end
+
+local globalLibraryNames = {
+ 'arg', 'assert', 'collectgarbage', 'dofile', '_G', 'getfenv',
+ 'getmetatable', 'ipairs', 'load', 'loadfile', 'loadstring',
+ 'module', 'next', 'pairs', 'pcall', 'print', 'rawequal',
+ 'rawget', 'rawlen', 'rawset', 'select', 'setfenv',
+ 'setmetatable', 'tonumber', 'tostring', 'type', '_VERSION',
+ 'warn', 'xpcall', 'require', 'unpack', 'bit32', 'coroutine',
+ 'debug', 'io', 'math', 'os', 'package', 'string', 'table',
+ 'utf8',
+}
+local globalLibraryNamesMap
+function vm.isGlobalLibraryName(name)
+ if not globalLibraryNamesMap then
+ globalLibraryNamesMap = {}
+ for _, v in ipairs(globalLibraryNames) do
+ globalLibraryNamesMap[v] = true
+ end
+ end
+ return globalLibraryNamesMap[name] or false
+end
diff --git a/script/vm/getLinks.lua b/script/vm/getLinks.lua
new file mode 100644
index 00000000..0bb1c6ff
--- /dev/null
+++ b/script/vm/getLinks.lua
@@ -0,0 +1,61 @@
+local guide = require 'parser.guide'
+local vm = require 'vm.vm'
+local files = require 'files'
+
+local function getFileLinks(uri)
+ local ws = require 'workspace'
+ local links = {}
+ local ast = files.getAst(uri)
+ if not ast then
+ return links
+ end
+ guide.eachSpecialOf(ast.ast, 'require', function (source)
+ local call = source.parent
+ if not call or call.type ~= 'call' then
+ return
+ end
+ local args = call.args
+ if not args[1] or args[1].type ~= 'string' then
+ return
+ end
+ local uris = ws.findUrisByRequirePath(args[1][1])
+ for _, u in ipairs(uris) do
+ u = files.asKey(u)
+ if not links[u] then
+ links[u] = {}
+ end
+ links[u][#links[u]+1] = call
+ end
+ end)
+ return links
+end
+
+local function getLinksTo(uri)
+ uri = files.asKey(uri)
+ local links = {}
+ for u in files.eachFile() do
+ local ls = vm.getFileLinks(u)
+ if ls[uri] then
+ for _, l in ipairs(ls[uri]) do
+ links[#links+1] = l
+ end
+ end
+ end
+ return links
+end
+
+function vm.getLinksTo(uri)
+ local cache = vm.getCache('getLinksTo')[uri]
+ if cache ~= nil then
+ return cache
+ end
+ cache = getLinksTo(uri)
+ vm.getCache('getLinksTo')[uri] = cache
+ return cache
+end
+
+function vm.getFileLinks(uri)
+ local cache = files.getCache(uri)
+ cache.links = cache.links or getFileLinks(uri)
+ return cache.links
+end
diff --git a/script/vm/getMeta.lua b/script/vm/getMeta.lua
new file mode 100644
index 00000000..aebef1a7
--- /dev/null
+++ b/script/vm/getMeta.lua
@@ -0,0 +1,52 @@
+local vm = require 'vm.vm'
+
+local function eachMetaOfArg1(source, callback)
+ local node, index = vm.getArgInfo(source)
+ local special = vm.getSpecial(node)
+ if special == 'setmetatable' and index == 1 then
+ local mt = node.next.args[2]
+ if mt then
+ callback(mt)
+ end
+ end
+end
+
+local function eachMetaOfRecv(source, callback)
+ if not source or source.type ~= 'select' then
+ return
+ end
+ if source.index ~= 1 then
+ return
+ end
+ local call = source.vararg
+ if not call or call.type ~= 'call' then
+ return
+ end
+ local special = vm.getSpecial(call.node)
+ if special ~= 'setmetatable' then
+ return
+ end
+ local mt = call.args[2]
+ if mt then
+ callback(mt)
+ end
+end
+
+function vm.eachMetaValue(source, callback)
+ vm.eachMeta(source, function (mt)
+ for _, src in ipairs(vm.getFields(mt)) do
+ if vm.getKeyName(src) == 's|__index' then
+ if src.value then
+ for _, valueSrc in ipairs(vm.getFields(src.value)) do
+ callback(valueSrc)
+ end
+ end
+ end
+ end
+ end)
+end
+
+function vm.eachMeta(source, callback)
+ eachMetaOfArg1(source, callback)
+ eachMetaOfRecv(source.value, callback)
+end
diff --git a/script/vm/guideInterface.lua b/script/vm/guideInterface.lua
new file mode 100644
index 00000000..e646def8
--- /dev/null
+++ b/script/vm/guideInterface.lua
@@ -0,0 +1,106 @@
+local vm = require 'vm.vm'
+local files = require 'files'
+local ws = require 'workspace'
+local guide = require 'parser.guide'
+local await = require 'await'
+local config = require 'config'
+
+local m = {}
+
+function m.searchFileReturn(results, ast, index)
+ local returns = ast.returns
+ if not returns then
+ return
+ end
+ for _, ret in ipairs(returns) do
+ local exp = ret[index]
+ if exp then
+ vm.mergeResults(results, { exp })
+ end
+ end
+end
+
+function m.require(args, index)
+ local reqName = args[1] and args[1][1]
+ if not reqName then
+ return nil
+ end
+ local results = {}
+ local myUri = guide.getUri(args[1])
+ local uris = ws.findUrisByRequirePath(reqName)
+ for _, uri in ipairs(uris) do
+ if not files.eq(myUri, uri) then
+ local ast = files.getAst(uri)
+ if ast then
+ m.searchFileReturn(results, ast.ast, index)
+ end
+ end
+ end
+
+ return results
+end
+
+function m.dofile(args, index)
+ local reqName = args[1] and args[1][1]
+ if not reqName then
+ return
+ end
+ local results = {}
+ local myUri = guide.getUri(args[1])
+ local uris = ws.findUrisByFilePath(reqName)
+ for _, uri in ipairs(uris) do
+ if not files.eq(myUri, uri) then
+ local ast = files.getAst(uri)
+ if ast then
+ m.searchFileReturn(results, ast.ast, index)
+ end
+ end
+ end
+ return results
+end
+
+vm.interface = {}
+
+-- 向前寻找引用的层数限制,一般情况下都为0
+-- 在自动完成/漂浮提示等情况时设置为5(需要清空缓存)
+-- 在查找引用时设置为10(需要清空缓存)
+vm.interface.searchLevel = 0
+
+function vm.interface.call(func, args, index)
+ if func.special == 'require' and index == 1 then
+ await.delay()
+ return m.require(args, index)
+ end
+ if func.special == 'dofile' then
+ await.delay()
+ return m.dofile(args, index)
+ end
+end
+
+function vm.interface.global(name)
+ await.delay()
+ return vm.getGlobals(name)
+end
+
+function vm.interface.docType(name)
+ await.delay()
+ return vm.getDocTypes(name)
+end
+
+function vm.interface.link(uri)
+ await.delay()
+ return vm.getLinksTo(uri)
+end
+
+function vm.interface.index(obj)
+ return nil
+end
+
+function vm.interface.cache()
+ await.delay()
+ return vm.getCache('cache')
+end
+
+function vm.interface.getSearchDepth()
+ return config.config.intelliSense.searchDepth
+end
diff --git a/script/vm/init.lua b/script/vm/init.lua
new file mode 100644
index 00000000..b9e8e147
--- /dev/null
+++ b/script/vm/init.lua
@@ -0,0 +1,13 @@
+local vm = require 'vm.vm'
+require 'vm.getGlobals'
+require 'vm.getDocs'
+require 'vm.getLibrary'
+require 'vm.getInfer'
+require 'vm.getClass'
+require 'vm.getMeta'
+require 'vm.eachField'
+require 'vm.eachDef'
+require 'vm.eachRef'
+require 'vm.getLinks'
+require 'vm.guideInterface'
+return vm
diff --git a/script/vm/vm.lua b/script/vm/vm.lua
new file mode 100644
index 00000000..e942d55e
--- /dev/null
+++ b/script/vm/vm.lua
@@ -0,0 +1,167 @@
+local guide = require 'parser.guide'
+local util = require 'utility'
+local files = require 'files'
+local timer = require 'timer'
+
+local setmetatable = setmetatable
+local assert = assert
+local require = require
+local type = type
+local running = coroutine.running
+local ipairs = ipairs
+local log = log
+local xpcall = xpcall
+local mathHuge = math.huge
+local collectgarbage = collectgarbage
+
+_ENV = nil
+
+---@class vm
+local m = {}
+
+function m.lock(tp, source)
+ local co = running()
+ local master = m.locked[co]
+ if not master then
+ master = {}
+ m.locked[co] = master
+ end
+ if not master[tp] then
+ master[tp] = {}
+ end
+ if master[tp][source] then
+ return nil
+ end
+ master[tp][source] = true
+ return function ()
+ master[tp][source] = nil
+ end
+end
+
+function m.isSet(src)
+ local tp = src.type
+ if tp == 'setglobal'
+ or tp == 'local'
+ or tp == 'setlocal'
+ or tp == 'setfield'
+ or tp == 'setmethod'
+ or tp == 'setindex'
+ or tp == 'tablefield'
+ or tp == 'tableindex' then
+ return true
+ end
+ if tp == 'call' then
+ local special = m.getSpecial(src.node)
+ if special == 'rawset' then
+ return true
+ end
+ end
+ return false
+end
+
+function m.isGet(src)
+ local tp = src.type
+ if tp == 'getglobal'
+ or tp == 'getlocal'
+ or tp == 'getfield'
+ or tp == 'getmethod'
+ or tp == 'getindex' then
+ return true
+ end
+ if tp == 'call' then
+ local special = m.getSpecial(src.node)
+ if special == 'rawget' then
+ return true
+ end
+ end
+ return false
+end
+
+function m.getArgInfo(source)
+ local callargs = source.parent
+ if not callargs or callargs.type ~= 'callargs' then
+ return nil
+ end
+ local call = callargs.parent
+ if not call or call.type ~= 'call' then
+ return nil
+ end
+ for i = 1, #callargs do
+ if callargs[i] == source then
+ return call.node, i
+ end
+ end
+ return nil
+end
+
+function m.getSpecial(source)
+ if not source then
+ return nil
+ end
+ return source.special
+end
+
+function m.getKeyName(source)
+ if not source then
+ return nil
+ end
+ if source.type == 'call' then
+ local special = m.getSpecial(source.node)
+ if special == 'rawset'
+ or special == 'rawget' then
+ return guide.getKeyNameOfLiteral(source.args[2])
+ end
+ end
+ return guide.getKeyName(source)
+end
+
+function m.mergeResults(a, b)
+ for _, r in ipairs(b) do
+ if not a[r] then
+ a[r] = true
+ a[#a+1] = r
+ end
+ end
+ return a
+end
+
+m.cacheTracker = setmetatable({}, { __mode = 'kv' })
+
+function m.flushCache()
+ if m.cache then
+ m.cache.dead = true
+ end
+ m.cacheVersion = files.globalVersion
+ m.cache = {}
+ m.cacheActiveTime = mathHuge
+ m.locked = setmetatable({}, { __mode = 'k' })
+ m.cacheTracker[m.cache] = true
+end
+
+function m.getCache(name)
+ if m.cacheVersion ~= files.globalVersion then
+ m.flushCache()
+ end
+ m.cacheActiveTime = timer.clock()
+ if not m.cache[name] then
+ m.cache[name] = {}
+ end
+ return m.cache[name]
+end
+
+local function init()
+ m.flushCache()
+
+ -- 可以在一段时间不活动后清空缓存,不过目前看起来没有必要
+ --timer.loop(1, function ()
+ -- if timer.clock() - m.cacheActiveTime > 10.0 then
+ -- log.info('Flush cache: Inactive')
+ -- m.flushCache()
+ -- collectgarbage()
+ -- end
+ --end)
+end
+
+xpcall(init, log.error)
+
+return m
diff --git a/script/without-check-nil.lua b/script/without-check-nil.lua
new file mode 100644
index 00000000..cc7da9d4
--- /dev/null
+++ b/script/without-check-nil.lua
@@ -0,0 +1,126 @@
+local m = {}
+
+local mt = {}
+mt.__add = function (a, b)
+ if a == nil then a = 0 end
+ if b == nil then b = 0 end
+ return a + b
+end
+mt.__sub = function (a, b)
+ if a == nil then a = 0 end
+ if b == nil then b = 0 end
+ return a - b
+end
+mt.__mul = function (a, b)
+ if a == nil then a = 0 end
+ if b == nil then b = 0 end
+ return a * b
+end
+mt.__div = function (a, b)
+ if a == nil then a = 0 end
+ if b == nil then b = 0 end
+ return a / b
+end
+mt.__mod = function (a, b)
+ if a == nil then a = 0 end
+ if b == nil then b = 0 end
+ return a % b
+end
+mt.__pow = function (a, b)
+ if a == nil then a = 0 end
+ if b == nil then b = 0 end
+ return a ^ b
+end
+mt.__unm = function ()
+ return 0
+end
+mt.__concat = function (a, b)
+ if a == nil then a = '' end
+ if b == nil then b = '' end
+ return a .. b
+end
+mt.__len = function ()
+ return 0
+end
+mt.__lt = function (a, b)
+ if a == nil then a = 0 end
+ if b == nil then b = 0 end
+ return a < b
+end
+mt.__le = function (a, b)
+ if a == nil then a = 0 end
+ if b == nil then b = 0 end
+ return a <= b
+end
+mt.__index = function () end
+mt.__newindex = function () end
+mt.__call = function () end
+mt.__pairs = function () end
+mt.__ipairs = function () end
+if _VERSION == 'Lua 5.3' or _VERSION == 'Lua 5.4' then
+ mt.__idiv = load[[
+ local a, b = ...
+ if a == nil then a = 0 end
+ if b == nil then b = 0 end
+ return a // b
+ ]]
+ mt.__band = load[[
+ local a, b = ...
+ if a == nil then a = 0 end
+ if b == nil then b = 0 end
+ return a & b
+ ]]
+ mt.__bor = load[[
+ local a, b = ...
+ if a == nil then a = 0 end
+ if b == nil then b = 0 end
+ return a | b
+ ]]
+ mt.__bxor = load[[
+ local a, b = ...
+ if a == nil then a = 0 end
+ if b == nil then b = 0 end
+ return a ~ b
+ ]]
+ mt.__bnot = load[[
+ return ~ 0
+ ]]
+ mt.__shl = load[[
+ local a, b = ...
+ if a == nil then a = 0 end
+ if b == nil then b = 0 end
+ return a << b
+ ]]
+ mt.__shr = load[[
+ local a, b = ...
+ if a == nil then a = 0 end
+ if b == nil then b = 0 end
+ return a >> b
+ ]]
+end
+
+for event, func in pairs(mt) do
+ mt[event] = function (...)
+ local watch = m.watch
+ if not watch then
+ return func(...)
+ end
+ local care, result = watch(event, ...)
+ if not care then
+ return func(...)
+ end
+ return result
+ end
+end
+
+function m.enable()
+ debug.setmetatable(nil, mt)
+end
+
+function m.disable()
+ if debug.getmetatable(nil) == mt then
+ debug.setmetatable(nil, nil)
+ end
+end
+
+return m
diff --git a/script/workspace/init.lua b/script/workspace/init.lua
new file mode 100644
index 00000000..7cbe15d7
--- /dev/null
+++ b/script/workspace/init.lua
@@ -0,0 +1,3 @@
+local workspace = require 'workspace.workspace'
+
+return workspace
diff --git a/script/workspace/require-path.lua b/script/workspace/require-path.lua
new file mode 100644
index 00000000..cfdc0455
--- /dev/null
+++ b/script/workspace/require-path.lua
@@ -0,0 +1,74 @@
+local platform = require 'bee.platform'
+local files = require 'files'
+local furi = require 'file-uri'
+local m = {}
+
+m.cache = {}
+
+--- `aaa/bbb/ccc.lua` 与 `?.lua` 将返回 `aaa.bbb.cccc`
+local function getOnePath(path, searcher)
+ local stemPath = path
+ : gsub('%.[^%.]+$', '')
+ : gsub('[/\\]+', '.')
+ local stemSearcher = searcher
+ : gsub('%.[^%.]+$', '')
+ : gsub('[/\\]+', '.')
+ local start = stemSearcher:match '()%?' or 1
+ for pos = start, #stemPath do
+ local word = stemPath:sub(start, pos)
+ local newSearcher = stemSearcher:gsub('%?', word)
+ if newSearcher == stemPath then
+ return word
+ end
+ end
+ return nil
+end
+
+function m.getVisiblePath(path, searchers)
+ path = path:gsub('^[/\\]+', '')
+ local uri = furi.encode(path)
+ local libraryPath = files.getLibraryPath(uri)
+ if not m.cache[path] then
+ local result = {}
+ m.cache[path] = result
+ local pos = 1
+ if libraryPath then
+ pos = #libraryPath + 2
+ end
+ repeat
+ local cutedPath = path:sub(pos)
+ local head
+ if pos > 1 then
+ head = path:sub(1, pos - 1)
+ end
+ pos = path:match('[/\\]+()', pos)
+ for _, searcher in ipairs(searchers) do
+ if platform.OS == 'Windows' then
+ searcher = searcher:gsub('[/\\]+', '\\')
+ else
+ searcher = searcher:gsub('[/\\]+', '/')
+ end
+ local expect = getOnePath(cutedPath, searcher)
+ if expect then
+ if head then
+ searcher = head .. searcher
+ end
+ result[#result+1] = {
+ searcher = searcher,
+ expect = expect,
+ }
+ end
+ end
+ if not pos then
+ break
+ end
+ until not pos
+ end
+ return m.cache[path]
+end
+
+function m.flush()
+ m.cache = {}
+end
+
+return m
diff --git a/script/workspace/workspace.lua b/script/workspace/workspace.lua
new file mode 100644
index 00000000..96b982b8
--- /dev/null
+++ b/script/workspace/workspace.lua
@@ -0,0 +1,330 @@
+local pub = require 'pub'
+local fs = require 'bee.filesystem'
+local furi = require 'file-uri'
+local files = require 'files'
+local config = require 'config'
+local glob = require 'glob'
+local platform = require 'bee.platform'
+local await = require 'await'
+local rpath = require 'workspace.require-path'
+local proto = require 'proto.proto'
+local lang = require 'language'
+local library = require 'library'
+local sp = require 'bee.subprocess'
+
+local m = {}
+m.type = 'workspace'
+m.nativeVersion = -1
+m.libraryVersion = -1
+m.nativeMatcher = nil
+m.requireCache = {}
+m.matchOption = {
+ ignoreCase = platform.OS == 'Windows',
+}
+
+--- 初始化工作区
+function m.init(uri)
+ log.info('Workspace inited: ', uri)
+ if not uri then
+ return
+ end
+ m.uri = uri
+ m.path = m.normalize(furi.decode(uri))
+ local logPath = ROOT / 'log' / (uri:gsub('[/:]+', '_') .. '.log')
+ log.info('Log path: ', logPath)
+ log.init(ROOT, logPath)
+end
+
+local function interfaceFactory(root)
+ return {
+ type = function (path)
+ if fs.is_directory(fs.path(root .. '/' .. path)) then
+ return 'directory'
+ else
+ return 'file'
+ end
+ end,
+ list = function (path)
+ local fullPath = fs.path(root .. '/' .. path)
+ if not fs.exists(fullPath) then
+ return nil
+ end
+ local paths = {}
+ pcall(function ()
+ for fullpath in fullPath:list_directory() do
+ paths[#paths+1] = fullpath:string()
+ end
+ end)
+ return paths
+ end
+ }
+end
+
+--- 创建排除文件匹配器
+function m.getNativeMatcher()
+ if not m.path then
+ return nil
+ end
+ if m.nativeVersion == config.version then
+ return m.nativeMatcher
+ end
+
+ local interface = interfaceFactory(m.path)
+ local pattern = {}
+ -- config.workspace.ignoreDir
+ for path in pairs(config.config.workspace.ignoreDir) do
+ log.info('Ignore directory:', path)
+ pattern[#pattern+1] = path
+ end
+ -- config.files.exclude
+ for path, ignore in pairs(config.other.exclude) do
+ if ignore then
+ log.info('Ignore by exclude:', path)
+ pattern[#pattern+1] = path
+ end
+ end
+ -- config.workspace.ignoreSubmodules
+ if config.config.workspace.ignoreSubmodules then
+ local buf = pub.awaitTask('loadFile', furi.encode(m.path .. '/.gitmodules'))
+ if buf then
+ for path in buf:gmatch('path = ([^\r\n]+)') do
+ log.info('Ignore by .gitmodules:', path)
+ pattern[#pattern+1] = path
+ end
+ end
+ end
+ -- config.workspace.useGitIgnore
+ if config.config.workspace.useGitIgnore then
+ local buf = pub.awaitTask('loadFile', furi.encode(m.path .. '/.gitignore'))
+ if buf then
+ for line in buf:gmatch '[^\r\n]+' do
+ log.info('Ignore by .gitignore:', line)
+ pattern[#pattern+1] = line
+ end
+ end
+ buf = pub.awaitTask('loadFile', furi.encode(m.path .. '/.git/info/exclude'))
+ if buf then
+ for line in buf:gmatch '[^\r\n]+' do
+ log.info('Ignore by .git/info/exclude:', line)
+ pattern[#pattern+1] = line
+ end
+ end
+ end
+ -- config.workspace.library
+ for path in pairs(config.config.workspace.library) do
+ log.info('Ignore by library:', path)
+ pattern[#pattern+1] = path
+ end
+
+ m.nativeMatcher = glob.gitignore(pattern, m.matchOption, interface)
+
+ m.nativeVersion = config.version
+ return m.nativeMatcher
+end
+
+--- 创建代码库筛选器
+function m.getLibraryMatchers()
+ if m.libraryVersion == config.version then
+ return m.libraryMatchers
+ end
+
+ local librarys = {}
+ for path, pattern in pairs(config.config.workspace.library) do
+ librarys[path] = pattern
+ end
+ if library.metaPath then
+ librarys[library.metaPath] = true
+ end
+ m.libraryMatchers = {}
+ for path, pattern in pairs(librarys) do
+ local nPath = fs.absolute(fs.path(path)):string()
+ local matcher = glob.gitignore(pattern, m.matchOption)
+ if platform.OS == 'Windows' then
+ matcher:setOption 'ignoreCase'
+ end
+ log.debug('getLibraryMatchers', path, nPath)
+ m.libraryMatchers[#m.libraryMatchers+1] = {
+ path = nPath,
+ matcher = matcher
+ }
+ end
+
+ m.libraryVersion = config.version
+ return m.libraryMatchers
+end
+
+--- 文件是否被忽略
+function m.isIgnored(uri)
+ local path = furi.decode(uri)
+ local ignore = m.getNativeMatcher()
+ if not ignore then
+ return false
+ end
+ return ignore(path)
+end
+
+local function loadFileFactory(root, progress, isLibrary)
+ return function (path)
+ local uri = furi.encode(root .. '/' .. path)
+ if not files.isLua(uri) then
+ return
+ end
+ if progress.preload >= config.config.workspace.maxPreload then
+ if not m.hasHitMaxPreload then
+ m.hasHitMaxPreload = true
+ proto.notify('window/showMessage', {
+ type = 3,
+ message = lang.script('MWS_MAX_PRELOAD', config.config.workspace.maxPreload),
+ })
+ end
+ return
+ end
+ if not isLibrary then
+ progress.preload = progress.preload + 1
+ end
+ progress.max = progress.max + 1
+ pub.task('loadFile', uri, function (text)
+ progress.read = progress.read + 1
+ --log.info(('Preload file at: %s , size = %.3f KB'):format(uri, #text / 1000.0))
+ if isLibrary then
+ files.setLibraryPath(uri, root)
+ end
+ files.setText(uri, text)
+ end)
+ end
+end
+
+--- 预读工作区内所有文件
+function m.awaitPreload()
+ await.close 'preload'
+ await.setID 'preload'
+ local progress = {
+ max = 0,
+ read = 0,
+ preload = 0,
+ }
+ log.info('Preload start.')
+ local nativeLoader = loadFileFactory(m.path, progress)
+ local native = m.getNativeMatcher()
+ local librarys = m.getLibraryMatchers()
+ if native then
+ native:scan(nativeLoader)
+ end
+ for _, library in ipairs(librarys) do
+ local libraryInterface = interfaceFactory(library.path)
+ local libraryLoader = loadFileFactory(library.path, progress, true)
+ for k, v in pairs(libraryInterface) do
+ library.matcher:setInterface(k, v)
+ end
+ library.matcher:scan(libraryLoader)
+ end
+
+ log.info(('Found %d files.'):format(progress.max))
+ while true do
+ log.info(('Loaded %d/%d files'):format(progress.read, progress.max))
+ if progress.read >= progress.max then
+ break
+ end
+ await.sleep(0.1)
+ end
+
+ --for i = 1, 100 do
+ -- await.sleep(0.1)
+ -- log.info('sleep', i)
+ --end
+
+ log.info('Preload finish.')
+
+ local diagnostic = require 'provider.diagnostic'
+ diagnostic.start()
+end
+
+--- 查找符合指定file path的所有uri
+---@param path string
+function m.findUrisByFilePath(path)
+ if type(path) ~= 'string' then
+ return {}
+ end
+ local results = {}
+ local posts = {}
+ for uri in files.eachFile() do
+ local pathLen = #path
+ local uriLen = #uri
+ local seg = uri:sub(uriLen - pathLen, uriLen - pathLen)
+ if seg == '/' or seg == '\\' or seg == '' then
+ local see = uri:sub(uriLen - pathLen + 1, uriLen)
+ if files.eq(see, path) then
+ results[#results+1] = uri
+ posts[uri] = files.getOriginUri(uri):sub(1, uriLen - pathLen)
+ end
+ end
+ end
+ return results, posts
+end
+
+--- 查找符合指定require path的所有uri
+---@param path string
+function m.findUrisByRequirePath(path)
+ if type(path) ~= 'string' then
+ return {}
+ end
+ local results = {}
+ local mark = {}
+ local searchers = {}
+ local input = path:gsub('%.', '/')
+ :gsub('%%', '%%%%')
+ for _, luapath in ipairs(config.config.runtime.path) do
+ local part = luapath:gsub('%?', input)
+ local uris, posts = m.findUrisByFilePath(part)
+ for _, uri in ipairs(uris) do
+ if not mark[uri] then
+ mark[uri] = true
+ results[#results+1] = uri
+ searchers[uri] = posts[uri] .. luapath
+ end
+ end
+ end
+ return results, searchers
+end
+
+function m.normalize(path)
+ if platform.OS == 'Windows' then
+ path = path:gsub('[/\\]+', '\\')
+ :gsub('^%a+%:', function (str)
+ return str:upper()
+ end)
+ else
+ path = path:gsub('[/\\]+', '/')
+ end
+ return path:gsub('^[/\\]+', '')
+end
+
+function m.getRelativePath(uri)
+ local path = furi.decode(uri)
+ if not m.path then
+ return m.normalize(path)
+ end
+ local _, pos = m.normalize(path):lower():find(m.path:lower(), 1, true)
+ if pos then
+ return m.normalize(path:sub(pos + 1))
+ else
+ return m.normalize(path)
+ end
+end
+
+function m.reload()
+ files.flushAllLibrary()
+ files.removeAllClosed()
+ rpath.flush()
+ await.call(m.awaitPreload)
+end
+
+files.watch(function (ev, uri)
+ if ev == 'close'
+ and m.isIgnored(uri)
+ and not files.isLibrary(uri) then
+ files.remove(uri)
+ end
+end)
+
+return m