summaryrefslogtreecommitdiff
path: root/script/core
diff options
context:
space:
mode:
Diffstat (limited to 'script/core')
-rw-r--r--script/core/code_action.lua410
-rw-r--r--script/core/completion.lua1079
-rw-r--r--script/core/definition.lua296
-rw-r--r--script/core/diagnostics.lua1042
-rw-r--r--script/core/document_symbol.lua260
-rw-r--r--script/core/find_lib.lua65
-rw-r--r--script/core/find_source.lua59
-rw-r--r--script/core/folding_range.lua73
-rw-r--r--script/core/global.lua49
-rw-r--r--script/core/highlight.lua54
-rw-r--r--script/core/hover/emmy_function.lua143
-rw-r--r--script/core/hover/function.lua243
-rw-r--r--script/core/hover/hover.lua326
-rw-r--r--script/core/hover/init.lua1
-rw-r--r--script/core/hover/lib_function.lua222
-rw-r--r--script/core/hover/name.lua38
-rw-r--r--script/core/implementation.lua204
-rw-r--r--script/core/init.lua19
-rw-r--r--script/core/library.lua296
-rw-r--r--script/core/matchKey.lua30
-rw-r--r--script/core/name.lua70
-rw-r--r--script/core/references.lua91
-rw-r--r--script/core/rename.lua72
-rw-r--r--script/core/signature.lua133
-rw-r--r--script/core/snippet.lua64
25 files changed, 5339 insertions, 0 deletions
diff --git a/script/core/code_action.lua b/script/core/code_action.lua
new file mode 100644
index 00000000..2c1fb14d
--- /dev/null
+++ b/script/core/code_action.lua
@@ -0,0 +1,410 @@
+local lang = require 'language'
+local library = require 'core.library'
+
+local function disableDiagnostic(lsp, uri, data, callback)
+ callback {
+ title = lang.script('ACTION_DISABLE_DIAG', data.code),
+ kind = 'quickfix',
+ command = {
+ title = lang.script.COMMAND_DISABLE_DIAG,
+ command = 'config',
+ arguments = {
+ {
+ key = {'diagnostics', 'disable'},
+ action = 'add',
+ value = data.code,
+ }
+ }
+ }
+ }
+end
+
+local function addGlobal(name, callback)
+ callback {
+ title = lang.script('ACTION_MARK_GLOBAL', name),
+ kind = 'quickfix',
+ command = {
+ title = lang.script.COMMAND_MARK_GLOBAL,
+ command = 'config',
+ arguments = {
+ {
+ key = {'diagnostics', 'globals'},
+ action = 'add',
+ value = name,
+ }
+ }
+ },
+ }
+end
+
+local function changeVersion(version, callback)
+ callback {
+ title = lang.script('ACTION_RUNTIME_VERSION', version),
+ kind = 'quickfix',
+ command = {
+ title = lang.script.COMMAND_RUNTIME_VERSION,
+ command = 'config',
+ arguments = {
+ {
+ key = {'runtime', 'version'},
+ action = 'set',
+ value = version,
+ }
+ }
+ },
+ }
+end
+
+local function openCustomLibrary(libName, callback)
+ callback {
+ title = lang.script('ACTION_OPEN_LIBRARY', libName),
+ kind = 'quickfix',
+ command = {
+ title = lang.script.COMMAND_OPEN_LIBRARY,
+ command = 'config',
+ arguments = {
+ {
+ key = {'runtime', 'library'},
+ action = 'add',
+ value = libName,
+ }
+ }
+ },
+ }
+end
+
+local function solveUndefinedGlobal(lsp, uri, data, callback)
+ local vm, lines, text = lsp:getVM(uri)
+ if not vm then
+ return
+ end
+ local start = lines:position(data.range.start.line + 1, data.range.start.character + 1)
+ local finish = lines:position(data.range['end'].line + 1, data.range['end'].character)
+ local name = text:sub(start, finish)
+ if #name < 0 or name:find('[^%w_]') then
+ return
+ end
+ addGlobal(name, callback)
+ local otherVersion = library.other[name]
+ if otherVersion then
+ for _, version in ipairs(otherVersion) do
+ changeVersion(version, callback)
+ end
+ end
+
+ local customLibrary = library.custom[name]
+ if customLibrary then
+ for _, libName in ipairs(customLibrary) do
+ openCustomLibrary(libName, callback)
+ end
+ end
+end
+
+local function solveLowercaseGlobal(lsp, uri, data, callback)
+ local vm, lines, text = lsp:getVM(uri)
+ if not vm then
+ return
+ end
+ local start = lines:position(data.range.start.line + 1, data.range.start.character + 1)
+ local finish = lines:position(data.range['end'].line + 1, data.range['end'].character)
+ local name = text:sub(start, finish)
+ if #name < 0 or name:find('[^%w_]') then
+ return
+ end
+ addGlobal(name, callback)
+end
+
+local function solveTrailingSpace(lsp, uri, data, callback)
+ callback {
+ title = lang.script.ACTION_REMOVE_SPACE,
+ kind = 'quickfix',
+ command = {
+ title = lang.script.COMMAND_REMOVE_SPACE,
+ command = 'removeSpace',
+ arguments = {
+ {
+ uri = uri,
+ }
+ }
+ },
+ }
+end
+
+local function solveNewlineCall(lsp, uri, data, callback)
+ callback {
+ title = lang.script.ACTION_ADD_SEMICOLON,
+ kind = 'quickfix',
+ edit = {
+ changes = {
+ [uri] = {
+ {
+ range = {
+ start = data.range.start,
+ ['end'] = data.range.start,
+ },
+ newText = ';',
+ }
+ }
+ }
+ }
+ }
+end
+
+local function solveAmbiguity1(lsp, uri, data, callback)
+ callback {
+ title = lang.script.ACTION_ADD_BRACKETS,
+ kind = 'quickfix',
+ command = {
+ title = lang.script.COMMAND_ADD_BRACKETS,
+ command = 'solve',
+ arguments = {
+ {
+ name = 'ambiguity-1',
+ uri = uri,
+ range = data.range,
+ }
+ }
+ },
+ }
+end
+
+local function findSyntax(astErr, lines, data)
+ local start = lines:position(data.range.start.line + 1, data.range.start.character + 1)
+ local finish = lines:position(data.range['end'].line + 1, data.range['end'].character)
+ for _, err in ipairs(astErr) do
+ if err.start == start and err.finish == finish then
+ return err
+ end
+ end
+ return nil
+end
+
+local function solveSyntaxByChangeVersion(err, callback)
+ if type(err.version) == 'table' then
+ for _, version in ipairs(err.version) do
+ changeVersion(version, callback)
+ end
+ else
+ changeVersion(err.version, callback)
+ end
+end
+
+local function solveSyntaxByAddDoEnd(uri, data, callback)
+ callback {
+ title = lang.script.ACTION_ADD_DO_END,
+ kind = 'quickfix',
+ edit = {
+ changes = {
+ [uri] = {
+ {
+ range = {
+ start = data.range.start,
+ ['end'] = data.range.start,
+ },
+ newText = 'do ',
+ },
+ {
+ range = {
+ start = data.range['end'],
+ ['end'] = data.range['end'],
+ },
+ newText = ' end',
+ }
+ }
+ }
+ }
+ }
+end
+
+local function solveSyntaxByFix(uri, err, lines, callback)
+ local changes = {}
+ for _, e in ipairs(err.fix) do
+ local start_row, start_col = lines:rowcol(e.start)
+ local finish_row, finish_col = lines:rowcol(e.finish)
+ changes[#changes+1] = {
+ range = {
+ start = {
+ line = start_row - 1,
+ character = start_col - 1,
+ },
+ ['end'] = {
+ line = finish_row - 1,
+ character = finish_col,
+ },
+ },
+ newText = e.text,
+ }
+ end
+ callback {
+ title = lang.script['ACTION_' .. err.fix.title],
+ kind = 'quickfix',
+ edit = {
+ changes = {
+ [uri] = changes,
+ }
+ }
+ }
+end
+
+local function findEndPosition(lines, row, endrow)
+ if endrow == row then
+ return {
+ newText = ' end',
+ range = {
+ start = {
+ line = row - 1,
+ character = 999999,
+ },
+ ['end'] = {
+ line = row - 1,
+ character = 999999,
+ }
+ }
+ }
+ else
+ local l = lines[row]
+ return {
+ newText = ('\t'):rep(l.tab) .. (' '):rep(l.sp) .. 'end\n',
+ range = {
+ start = {
+ line = endrow,
+ character = 0,
+ },
+ ['end'] = {
+ line = endrow,
+ character = 0,
+ }
+ }
+ }
+ end
+end
+
+local function isIfPart(id, lines, i)
+ if id ~= 'if' then
+ return false
+ end
+ local buf = lines:line(i)
+ local first = buf:match '^[%s\t]*([%w]+)'
+ if first == 'else' or first == 'elseif' then
+ return true
+ end
+ return false
+end
+
+local function solveSyntaxByAddEnd(uri, start, finish, lines, callback)
+ local row = lines:rowcol(start)
+ local line = lines[row]
+ if not line then
+ return nil
+ end
+ local id = lines.buf:sub(start, finish)
+ local sp = line.sp + line.tab * 4
+ for i = row + 1, #lines do
+ local nl = lines[i]
+ local lsp = nl.sp + nl.tab * 4
+ if lsp <= sp and not isIfPart(id, lines, i) then
+ callback {
+ title = lang.script['ACTION_ADD_END'],
+ kind = 'quickfix',
+ edit = {
+ changes = {
+ [uri] = {
+ findEndPosition(lines, row, i - 1)
+ }
+ }
+ }
+ }
+ return
+ end
+ end
+ return nil
+end
+
+---@param lsp LSP
+---@param uri uri
+---@param data table
+---@param callback function
+local function solveSyntax(lsp, uri, data, callback)
+ local file = lsp:getFile(uri)
+ if not file then
+ return
+ end
+ local astErr, lines = file:getAstErr(), file:getLines()
+ if not astErr or not lines then
+ return
+ end
+ local err = findSyntax(astErr, lines, data)
+ if not err then
+ return nil
+ end
+ if err.version then
+ solveSyntaxByChangeVersion(err, callback)
+ end
+ if err.type == 'ACTION_AFTER_BREAK' or err.type == 'ACTION_AFTER_RETURN' then
+ solveSyntaxByAddDoEnd(uri, data, callback)
+ end
+ if err.type == 'MISS_END' then
+ solveSyntaxByAddEnd(uri, err.start, err.finish, lines, callback)
+ end
+ if err.type == 'MISS_SYMBOL' and err.info.symbol == 'end' then
+ solveSyntaxByAddEnd(uri, err.info.related[1], err.info.related[2], lines, callback)
+ end
+ if err.fix then
+ solveSyntaxByFix(uri, err, lines, callback)
+ end
+end
+
+local function solveDiagnostic(lsp, uri, data, callback)
+ if data.source == lang.script.DIAG_SYNTAX_CHECK then
+ solveSyntax(lsp, uri, data, callback)
+ end
+ if not data.code then
+ return
+ end
+ if data.code == 'undefined-global' then
+ solveUndefinedGlobal(lsp, uri, data, callback)
+ end
+ if data.code == 'trailing-space' then
+ solveTrailingSpace(lsp, uri, data, callback)
+ end
+ if data.code == 'newline-call' then
+ solveNewlineCall(lsp, uri, data, callback)
+ end
+ if data.code == 'ambiguity-1' then
+ solveAmbiguity1(lsp, uri, data, callback)
+ end
+ if data.code == 'lowercase-global' then
+ solveLowercaseGlobal(lsp, uri, data, callback)
+ end
+ disableDiagnostic(lsp, uri, data, callback)
+end
+
+local function rangeContain(a, b)
+ if a.start.line > b.start.line then
+ return false
+ end
+ if a.start.character > b.start.character then
+ return false
+ end
+ if a['end'].line < b['end'].line then
+ return false
+ end
+ if a['end'].character < b['end'].character then
+ return false
+ end
+ return true
+end
+
+return function (lsp, uri, diagnostics, range)
+ local results = {}
+
+ for _, data in ipairs(diagnostics) do
+ if rangeContain(data.range, range) then
+ solveDiagnostic(lsp, uri, data, function (result)
+ results[#results+1] = result
+ end)
+ end
+ end
+
+ return results
+end
diff --git a/script/core/completion.lua b/script/core/completion.lua
new file mode 100644
index 00000000..756f136b
--- /dev/null
+++ b/script/core/completion.lua
@@ -0,0 +1,1079 @@
+local findSource = require 'core.find_source'
+local getFunctionHover = require 'core.hover.function'
+local getFunctionHoverAsLib = require 'core.hover.lib_function'
+local getFunctionHoverAsEmmy = require 'core.hover.emmy_function'
+local sourceMgr = require 'vm.source'
+local config = require 'config'
+local matchKey = require 'core.matchKey'
+local parser = require 'parser'
+local lang = require 'language'
+local snippet = require 'core.snippet'
+local State
+
+local CompletionItemKind = {
+ Text = 1,
+ Method = 2,
+ Function = 3,
+ Constructor = 4,
+ Field = 5,
+ Variable = 6,
+ Class = 7,
+ Interface = 8,
+ Module = 9,
+ Property = 10,
+ Unit = 11,
+ Value = 12,
+ Enum = 13,
+ Keyword = 14,
+ Snippet = 15,
+ Color = 16,
+ File = 17,
+ Reference = 18,
+ Folder = 19,
+ EnumMember = 20,
+ Constant = 21,
+ Struct = 22,
+ Event = 23,
+ Operator = 24,
+ TypeParameter = 25,
+}
+
+local KEYS = {'and', 'break', 'do', 'else', 'elseif', 'end', 'false', 'for', 'function', 'goto', 'if', 'in', 'local', 'nil', 'not', 'or', 'repeat', 'return', 'then', 'true', 'until', 'while'}
+local KEYMAP = {}
+for _, k in ipairs(KEYS) do
+ KEYMAP[k] = true
+end
+
+local EMMY_KEYWORD = {'class', 'type', 'alias', 'param', 'return', 'field', 'generic', 'vararg', 'language', 'see', 'overload'}
+
+local function getDucumentation(name, value)
+ if value:getType() == 'function' then
+ local lib = value:getLib()
+ local hover
+ if lib then
+ hover = getFunctionHoverAsLib(name, lib)
+ else
+ local emmy = value:getEmmy()
+ if emmy and emmy.type == 'emmy.functionType' then
+ hover = getFunctionHoverAsEmmy(name, emmy)
+ else
+ hover = getFunctionHover(name, value:getFunction())
+ end
+ end
+ if not hover then
+ return nil
+ end
+ local text = ([[
+```lua
+%s
+```
+%s
+```lua
+%s
+```
+%s
+]]):format(hover.label or '', hover.description or '', hover.enum or '', hover.doc or '')
+ return {
+ kind = 'markdown',
+ value = text,
+ }
+ end
+ local lib = value:getLib()
+ if lib then
+ return {
+ kind = 'markdown',
+ value = lib.description,
+ }
+ end
+ local comment = value:getComment()
+ if comment then
+ return {
+ kind = 'markdown',
+ value = comment,
+ }
+ end
+ return nil
+end
+
+local function getDetail(value)
+ local literal = value:getLiteral()
+ local tp = type(literal)
+ local detals = {}
+ if value:getType() ~= 'any' then
+ detals[#detals+1] = ('(%s)'):format(value:getType())
+ end
+ if tp == 'boolean' then
+ detals[#detals+1] = (' = %q'):format(literal)
+ elseif tp == 'string' then
+ detals[#detals+1] = (' = %q'):format(literal)
+ elseif tp == 'number' then
+ if math.type(literal) == 'integer' then
+ detals[#detals+1] = (' = %q'):format(literal)
+ else
+ local str = (' = %.16f'):format(literal)
+ local dot = str:find('.', 1, true)
+ local suffix = str:find('[0]+$', dot + 2)
+ if suffix then
+ detals[#detals+1] = str:sub(1, suffix - 1)
+ else
+ detals[#detals+1] = str
+ end
+ end
+ end
+ if value:getType() == 'function' then
+ ---@type emmyFunction
+ local func = value:getFunction()
+ local overLoads = func and func:getEmmyOverLoads()
+ if overLoads then
+ detals[#detals+1] = lang.script('HOVER_MULTI_PROTOTYPE', #overLoads + 1)
+ end
+ end
+ if #detals == 0 then
+ return nil
+ end
+ return table.concat(detals)
+end
+
+local function getKind(cata, value)
+ if value:getType() == 'function' then
+ local func = value:getFunction()
+ if func and func:getObject() then
+ return CompletionItemKind.Method
+ else
+ return CompletionItemKind.Function
+ end
+ end
+ if cata == 'field' then
+ local literal = value:getLiteral()
+ local tp = type(literal)
+ if tp == 'number' or tp == 'integer' or tp == 'string' then
+ return CompletionItemKind.Enum
+ end
+ end
+ return nil
+end
+
+local function buildSnipArgs(args, enums)
+ local t = {}
+ for i, arg in ipairs(args) do
+ local name = arg:match '^[^:]+'
+ local enum = enums and enums[name]
+ if enum and #enum > 0 then
+ t[i] = ('${%d|%s|}'):format(i, table.concat(enum, ','))
+ else
+ t[i] = ('${%d:%s}'):format(i, arg)
+ end
+ end
+ return table.concat(t, ', ')
+end
+
+local function getFunctionSnip(name, value, source)
+ if value:getType() ~= 'function' then
+ return
+ end
+ local lib = value:getLib()
+ local object = source:get 'object'
+ local hover
+ if lib then
+ hover = getFunctionHoverAsLib(name, lib, object)
+ else
+ local emmy = value:getEmmy()
+ if emmy and emmy.type == 'emmy.functionType' then
+ hover = getFunctionHoverAsEmmy(name, emmy, object)
+ else
+ hover = getFunctionHover(name, value:getFunction(), object)
+ end
+ end
+ if not hover then
+ return ('%s()'):format(name)
+ end
+ if not hover.args then
+ return ('%s()'):format(name)
+ end
+ return ('%s(%s)'):format(name, buildSnipArgs(hover.args, hover.rawEnum))
+end
+
+local function getValueData(cata, name, value, pos, source)
+ local data = {
+ documentation = getDucumentation(name, value),
+ detail = getDetail(value),
+ kind = getKind(cata, value),
+ snip = getFunctionSnip(name, value, source),
+ }
+ if cata == 'field' then
+ if not parser:grammar(name, 'Name') then
+ if source:get 'simple' and source:get 'simple' [1] ~= source then
+ data.textEdit = {
+ start = pos + 1,
+ finish = pos,
+ newText = ('[%q]'):format(name),
+ }
+ data.additionalTextEdits = {
+ {
+ start = pos,
+ finish = pos,
+ newText = '',
+ }
+ }
+ else
+ data.textEdit = {
+ start = pos + 1,
+ finish = pos,
+ newText = ('_ENV[%q]'):format(name),
+ }
+ data.additionalTextEdits = {
+ {
+ start = pos,
+ finish = pos,
+ newText = '',
+ }
+ }
+ end
+ end
+ end
+ return data
+end
+
+local function searchLocals(vm, source, word, callback, pos)
+ vm:eachSource(function (src)
+ local loc = src:bindLocal()
+ if not loc then
+ return
+ end
+
+ if src.start <= source.start
+ and loc:close() >= source.finish
+ and matchKey(word, loc:getName())
+ then
+ callback(loc:getName(), src, CompletionItemKind.Variable, getValueData('local', loc:getName(), loc:getValue(), pos, source))
+ end
+ end)
+end
+
+local function sortPairs(t)
+ local keys = {}
+ for k in pairs(t) do
+ keys[#keys+1] = k
+ end
+ table.sort(keys)
+ local i = 0
+ return function ()
+ i = i + 1
+ local k = keys[i]
+ return k, t[k]
+ end
+end
+
+local function searchFieldsByInfo(parent, word, source, map)
+ parent:eachInfo(function (info, src)
+ local k = info[1]
+ if src == source then
+ return
+ end
+ if map[k] then
+ return
+ end
+ if KEYMAP[k] then
+ return
+ end
+ if info.type ~= 'set child' and info.type ~= 'get child' then
+ return
+ end
+ if type(k) ~= 'string' then
+ return
+ end
+ local v = parent:getChild(k)
+ if not v then
+ return
+ end
+ if source:get 'object' and v:getType() ~= 'function' then
+ return
+ end
+ if matchKey(word, k) then
+ map[k] = v
+ end
+ end)
+end
+
+local function searchFieldsByChild(parent, word, source, map)
+ parent:eachChild(function (k, v)
+ if map[k] then
+ return
+ end
+ if KEYMAP[k] then
+ return
+ end
+ if not v:getLib() then
+ return
+ end
+ if type(k) ~= 'string' then
+ return
+ end
+ if source:get 'object' and v:getType() ~= 'function' then
+ return
+ end
+ if matchKey(word, k) then
+ map[k] = v
+ end
+ end)
+end
+
+---@param vm VM
+local function searchFields(vm, source, word, callback, pos)
+ local parent = source:get 'parent' or vm.env:getValue()
+ if not parent then
+ return
+ end
+ local map = {}
+ local current = parent
+ for _ = 1, 3 do
+ searchFieldsByInfo(current, word, source, map)
+ current = current:getMetaMethod('__index')
+ if not current then
+ break
+ end
+ end
+ searchFieldsByChild(parent, word, source, map)
+ for k, v in sortPairs(map) do
+ callback(k, nil, CompletionItemKind.Field, getValueData('field', k, v, pos, source))
+ end
+end
+
+local function searchIndex(vm, source, word, callback)
+ vm:eachSource(function (src)
+ if src:get 'table index' then
+ if matchKey(word, src[1]) then
+ callback(src[1], src, CompletionItemKind.Property)
+ end
+ end
+ end)
+end
+
+local function searchCloseGlobal(vm, start, finish, word, callback)
+ vm:eachSource(function (src)
+ if (src:get 'global' or src:bindLocal())
+ and src.start >= start
+ and src.finish <= finish
+ then
+ if matchKey(word, src[1]) then
+ callback(src[1], src, CompletionItemKind.Variable)
+ end
+ end
+ end)
+end
+
+local function searchParams(vm, source, func, word, callback)
+ if not func then
+ return
+ end
+ ---@type emmyFunction
+ local emmyParams = func:getEmmyParams()
+ if not emmyParams then
+ return
+ end
+ if #emmyParams > 1 then
+ if not func.args
+ or not func.args[1]
+ or func.args[1]:getSource() == source then
+ if matchKey(word, source and source[1] or '') then
+ local names = {}
+ for _, param in ipairs(emmyParams) do
+ local name = param:getName()
+ names[#names+1] = name
+ end
+ callback(table.concat(names, ', '), nil, CompletionItemKind.Snippet)
+ end
+ end
+ end
+ for _, param in ipairs(emmyParams) do
+ local name = param:getName()
+ if matchKey(word, name) then
+ callback(name, param:getSource(), CompletionItemKind.Interface)
+ end
+ end
+end
+
+local function searchKeyWords(vm, source, word, callback)
+ local snipType = config.config.completion.keywordSnippet
+ for _, key in ipairs(KEYS) do
+ if matchKey(word, key) then
+ if snippet.key[key] then
+ if snipType ~= 'Replace'
+ or key == 'local'
+ or key == 'return' then
+ callback(key, nil, CompletionItemKind.Keyword)
+ end
+ if snipType ~= 'Disable' then
+ for _, data in ipairs(snippet.key[key]) do
+ callback(data.label, nil, CompletionItemKind.Snippet, {
+ insertText = data.text,
+ })
+ end
+ end
+ else
+ callback(key, nil, CompletionItemKind.Keyword)
+ end
+ end
+ end
+end
+
+local function searchGlobals(vm, source, word, callback, pos)
+ local global = vm.env:getValue()
+ local map = {}
+ local current = global
+ for _ = 1, 3 do
+ searchFieldsByInfo(current, word, source, map)
+ current = current:getMetaMethod('__index')
+ if not current then
+ break
+ end
+ end
+ searchFieldsByChild(global, word, source, map)
+ for k, v in sortPairs(map) do
+ callback(k, nil, CompletionItemKind.Field, getValueData('field', k, v, pos, source))
+ end
+end
+
+local function searchAsGlobal(vm, source, word, callback, pos)
+ if word == '' then
+ return
+ end
+ searchLocals(vm, source, word, callback, pos)
+ searchFields(vm, source, word, callback, pos)
+ searchKeyWords(vm, source, word, callback)
+end
+
+local function searchAsKeyowrd(vm, source, word, callback, pos)
+ searchLocals(vm, source, word, callback, pos)
+ searchGlobals(vm, source, word, callback, pos)
+ searchKeyWords(vm, source, word, callback)
+end
+
+local function searchAsSuffix(vm, source, word, callback, pos)
+ searchFields(vm, source, word, callback, pos)
+end
+
+local function searchAsIndex(vm, source, word, callback, pos)
+ searchLocals(vm, source, word, callback, pos)
+ searchIndex(vm, source, word, callback)
+ searchFields(vm, source, word, callback, pos)
+end
+
+local function searchAsLocal(vm, source, word, callback)
+ local loc = source:bindLocal()
+ if not loc then
+ return
+ end
+ local close = loc:close()
+ -- 因为闭包的关系落在局部变量finish到close范围内的全局变量一定能访问到该局部变量
+ searchCloseGlobal(vm, source.finish, close, word, callback)
+ -- 特殊支持 local function
+ if matchKey(word, 'function') then
+ callback('function', nil, CompletionItemKind.Keyword)
+ -- TODO 需要有更优美的实现方式
+ local data = snippet.key['function'][1]
+ callback(data.label, nil, CompletionItemKind.Snippet, {
+ insertText = data.text,
+ })
+ end
+end
+
+local function searchAsArg(vm, source, word, callback)
+ searchParams(vm, source, source:get 'arg', word, callback)
+
+ local loc = source:bindLocal()
+ if loc then
+ local close = loc:close()
+ -- 因为闭包的关系落在局部变量finish到close范围内的全局变量一定能访问到该局部变量
+ searchCloseGlobal(vm, source.finish, close, word, callback)
+ return
+ end
+end
+
+local function searchFunction(vm, source, word, pos, callback)
+ if pos >= source.argStart and pos <= source.argFinish then
+ searchParams(vm, nil, source:bindFunction():getFunction(), word, callback)
+ searchCloseGlobal(vm, source.argFinish, source.finish, word, callback)
+ end
+end
+
+local function searchEmmyKeyword(vm, source, word, callback)
+ for _, kw in ipairs(EMMY_KEYWORD) do
+ if matchKey(word, kw) then
+ callback(kw, nil, CompletionItemKind.Keyword)
+ end
+ end
+end
+
+local function searchEmmyClass(vm, source, word, callback)
+ local classes = {}
+ vm.emmyMgr:eachClass(function (class)
+ if class.type == 'emmy.class' or class.type == 'emmy.alias' then
+ if matchKey(word, class:getName()) then
+ classes[#classes+1] = class
+ end
+ end
+ end)
+ table.sort(classes, function (a, b)
+ return a:getName() < b:getName()
+ end)
+ for _, class in ipairs(classes) do
+ callback(class:getName(), class:getSource(), CompletionItemKind.Class)
+ end
+end
+
+local function searchEmmyFunctionParam(vm, source, word, callback)
+ local func = source:get 'emmy function'
+ if not func.args then
+ return
+ end
+ if #func.args > 1 and matchKey(word, func.args[1].name) then
+ local list = {}
+ local args = {}
+ for i, arg in ipairs(func.args) do
+ if func:getObject() and i == 1 then
+ goto NEXT
+ end
+ args[#args+1] = arg.name
+ if #list == 0 then
+ list[#list+1] = ('%s any'):format(arg.name)
+ else
+ list[#list+1] = ('---@param %s any'):format(arg.name)
+ end
+ :: NEXT ::
+ end
+ callback(('%s'):format(table.concat(args, ', ')), nil, CompletionItemKind.Snippet, {
+ insertText = table.concat(list, '\n')
+ })
+ end
+ for i, arg in ipairs(func.args) do
+ if func:getObject() and i == 1 then
+ goto NEXT
+ end
+ if matchKey(word, arg.name) then
+ callback(arg.name, nil, CompletionItemKind.Interface)
+ end
+ :: NEXT ::
+ end
+end
+
+local function searchSource(vm, source, word, callback, pos)
+ if source.type == 'keyword' then
+ searchAsKeyowrd(vm, source, word, callback, pos)
+ return
+ end
+ if source:get 'table index' then
+ searchAsIndex(vm, source, word, callback, pos)
+ return
+ end
+ if source:get 'arg' then
+ searchAsArg(vm, source, word, callback)
+ return
+ end
+ if source:get 'global' then
+ searchAsGlobal(vm, source, word, callback, pos)
+ return
+ end
+ if source:action() == 'local' then
+ searchAsLocal(vm, source, word, callback)
+ return
+ end
+ if source:bindLocal() then
+ searchAsGlobal(vm, source, word, callback, pos)
+ return
+ end
+ if source:get 'simple'
+ and (source.type == 'name' or source.type == '.' or source.type == ':') then
+ searchAsSuffix(vm, source, word, callback, pos)
+ return
+ end
+ if source:bindFunction() then
+ searchFunction(vm, source, word, pos, callback)
+ return
+ end
+ if source.type == 'emmyIncomplete' then
+ searchEmmyKeyword(vm, source, word, callback)
+ State.ignoreText = true
+ return
+ end
+ if source:get 'emmy class' then
+ searchEmmyClass(vm, source, word, callback)
+ State.ignoreText = true
+ return
+ end
+ if source:get 'emmy function' then
+ searchEmmyFunctionParam(vm, source, word, callback)
+ State.ignoreText = true
+ return
+ end
+end
+
+local function buildTextEdit(start, finish, str, quo)
+ local text, lquo, rquo, label, filterText
+ if quo == nil then
+ local text = str:gsub('\r', '\\r'):gsub('\n', '\\n'):gsub('"', '\\"')
+ return {
+ label = '"' .. text .. '"'
+ }
+ end
+ if quo == '"' then
+ label = str
+ filterText = str
+ text = str:gsub('\r', '\\r'):gsub('\n', '\\n'):gsub('"', '\\"')
+ lquo = quo
+ rquo = quo
+ elseif quo == "'" then
+ label = str
+ filterText = str
+ text = str:gsub('\r', '\\r'):gsub('\n', '\\n'):gsub("'", "\\'")
+ lquo = quo
+ rquo = quo
+ else
+ label = str
+ filterText = str
+ lquo = quo
+ rquo = ']' .. lquo:sub(2, -2) .. ']'
+ while str:find(rquo, 1, true) do
+ lquo = '[=' .. quo:sub(2)
+ rquo = ']' .. lquo:sub(2, -2) .. ']'
+ end
+ text = str
+ end
+ return {
+ label = label,
+ filterText = filterText,
+ textEdit = {
+ start = start + #quo,
+ finish = finish - #quo,
+ newText = text,
+ },
+ additionalTextEdits = {
+ {
+ start = start,
+ finish = start + #quo - 1,
+ newText = lquo,
+ },
+ {
+ start = finish - #quo + 1,
+ finish = finish,
+ newText = rquo,
+ },
+ }
+ }
+end
+
+local function searchInRequire(vm, source, callback)
+ if not vm.lsp or not vm.lsp.workspace then
+ return
+ end
+ if source.type ~= 'string' then
+ return
+ end
+ local list, map = vm.lsp.workspace:matchPath(vm.uri, source[1])
+ if not list then
+ return
+ end
+ for _, str in ipairs(list) do
+ local data = buildTextEdit(source.start, source.finish, str, source[2])
+ data.documentation = map[str]
+ callback(str, nil, CompletionItemKind.Reference, data)
+ end
+end
+
+local function searchEnumAsLib(vm, source, word, callback, pos, args, lib)
+ local select = #args + 1
+ for i, arg in ipairs(args) do
+ if arg.start <= pos and arg.finish >= pos then
+ select = i
+ break
+ end
+ end
+
+ -- 根据参数位置找枚举值
+ if lib.args and lib.enums then
+ local arg = lib.args[select]
+ local name = arg and arg.name
+ for _, enum in ipairs(lib.enums) do
+ if enum.name and enum.name == name and enum.enum then
+ if matchKey(word, enum.enum) then
+ local strSource = parser:parse(tostring(enum.enum), 'String')
+ if strSource then
+ if source.type == 'string' then
+ local data = buildTextEdit(source.start, source.finish, strSource[1], source[2])
+ data.documentation = enum.description
+ callback(enum.enum, nil, CompletionItemKind.EnumMember, data)
+ else
+ callback(enum.enum, nil, CompletionItemKind.EnumMember, {
+ documentation = enum.description
+ })
+ end
+ end
+ else
+ callback(enum.enum, nil, CompletionItemKind.EnumMember, {
+ documentation = enum.description
+ })
+ end
+ end
+ end
+ end
+
+ -- 搜索特殊函数
+ if lib.special == 'require' then
+ if select == 1 then
+ searchInRequire(vm, source, callback)
+ end
+ end
+end
+
+local function buildEmmyEnumComment(enum, data)
+ if not enum.comment then
+ return data
+ end
+ if not data then
+ data = {}
+ end
+ data.documentation = tostring(enum.comment)
+ return data
+end
+
+local function searchEnumAsEmmyParams(vm, source, word, callback, pos, args, func)
+ local select = #args + 1
+ for i, arg in ipairs(args) do
+ if arg.start <= pos and arg.finish >= pos then
+ select = i
+ break
+ end
+ end
+
+ local param = func:findEmmyParamByIndex(select)
+ if not param then
+ return
+ end
+
+ param:eachEnum(function (enum)
+ local str = enum[1]
+ if matchKey(word, str) then
+ local strSource = parser:parse(tostring(str), 'String')
+ if strSource then
+ if source.type == 'string' then
+ local data = buildTextEdit(source.start, source.finish, strSource[1], source[2])
+ callback(str, nil, CompletionItemKind.EnumMember, buildEmmyEnumComment(enum, data))
+ else
+ callback(str, nil, CompletionItemKind.EnumMember, buildEmmyEnumComment(enum))
+ end
+ else
+ callback(str, nil, CompletionItemKind.EnumMember, buildEmmyEnumComment(enum))
+ end
+ end
+ end)
+
+ local option = param:getOption()
+ if option and option.special == 'require:1' then
+ searchInRequire(vm, source, callback)
+ end
+end
+
+local function getSelect(args, pos)
+ if not args then
+ return 1
+ end
+ for i, arg in ipairs(args) do
+ if arg.start <= pos and arg.finish >= pos - 1 then
+ return i
+ end
+ end
+ return #args + 1
+end
+
+local function isInFunctionOrTable(call, pos)
+ local args = call:bindCall()
+ if not args then
+ return false
+ end
+ local select = getSelect(args, pos)
+ local arg = args[select]
+ if not arg then
+ return false
+ end
+ if arg.type == 'function' or arg.type == 'table' then
+ return true
+ end
+ return false
+end
+
+local function searchCallArg(vm, source, word, callback, pos)
+ local results = {}
+ vm:eachSource(function (src)
+ if src.type == 'call'
+ and src.start <= pos
+ and src.finish >= pos
+ then
+ results[#results+1] = src
+ end
+ end)
+ if #results == 0 then
+ return nil
+ end
+ -- 可能处于 'func1(func2(' 的嵌套中,将最近的call放到最前面
+ table.sort(results, function (a, b)
+ return a.start > b.start
+ end)
+ local call = results[1]
+ if isInFunctionOrTable(call, pos) then
+ return
+ end
+
+ local args = call:bindCall()
+ if not args then
+ return
+ end
+
+ local value = call:findCallFunction()
+ if not value then
+ return
+ end
+
+ local lib = value:getLib()
+ if lib then
+ searchEnumAsLib(vm, source, word, callback, pos, args, lib)
+ return
+ end
+
+ ---@type emmyFunction
+ local func = value:getFunction()
+ if func then
+ searchEnumAsEmmyParams(vm, source, word, callback, pos, args, func)
+ return
+ end
+end
+
+local function searchAllWords(vm, source, word, callback, pos)
+ if word == '' then
+ return
+ end
+ if source.type == 'string' then
+ return
+ end
+ vm:eachSource(function (src)
+ if src.type == 'name'
+ and not (src.start <= pos and src.finish >= pos)
+ and matchKey(word, src[1])
+ then
+ callback(src[1], src, CompletionItemKind.Text)
+ end
+ end)
+end
+
+local function searchSpecialHashSign(vm, pos, text, callback)
+ -- 尝试 XXX[#XXX+1]
+ -- 1. 搜索 []
+ local index
+ vm:eachSource(function (src)
+ if src.type == 'index'
+ and src.start <= pos
+ and src.finish >= pos
+ then
+ index = src
+ return true
+ end
+ end)
+ if not index then
+ return nil
+ end
+ -- 2. [] 内部只能有一个 #
+ local inside = index[1]
+ if not inside then
+ return nil
+ end
+ if inside.op ~= '#' then
+ return nil
+ end
+ -- 3. [] 左侧必须是 simple ,且index 是 simple 的最后一项
+ local simple = index:get 'simple'
+ if not simple then
+ return nil
+ end
+ if simple[#simple] ~= index then
+ return nil
+ end
+ local chars = text:sub(simple.start, simple[#simple-1].finish)
+ -- 4. 创建代码片段
+ if simple:get 'as action' then
+ local label = chars .. '+1'
+ callback(label, nil, CompletionItemKind.Snippet, {
+ textEdit = {
+ start = inside.start + 1,
+ finish = index.finish,
+ newText = ('%s] = '):format(label),
+ },
+ })
+ else
+ local label = chars
+ callback(label, nil, CompletionItemKind.Snippet, {
+ textEdit = {
+ start = inside.start + 1,
+ finish = index.finish,
+ newText = ('%s]'):format(label),
+ },
+ })
+ end
+end
+
+local function searchSpecial(vm, source, word, callback, pos, text)
+ searchSpecialHashSign(vm, pos, text, callback)
+end
+
+local function makeList(source, pos, word)
+ local list = {}
+ local mark = {}
+ return function (name, src, kind, data)
+ if src == source then
+ return
+ end
+ if word == name then
+ if src and src.start <= pos and src.finish >= pos then
+ return
+ end
+ end
+ if mark[name] then
+ return
+ end
+ mark[name] = true
+ if not data then
+ data = {}
+ end
+ if not data.label then
+ data.label = name
+ end
+ if not data.kind then
+ data.kind = kind
+ end
+ list[#list+1] = data
+ if data.snip then
+ local snipType = config.config.completion.callSnippet
+ if snipType ~= 'Disable' then
+ local snipData = table.deepCopy(data)
+ snipData.insertText = data.snip
+ snipData.kind = CompletionItemKind.Snippet
+ snipData.label = snipData.label .. '()'
+ snipData.snip = nil
+ if snipType == 'Both' then
+ list[#list+1] = snipData
+ elseif snipType == 'Replace' then
+ list[#list] = snipData
+ end
+ end
+ data.snip = nil
+ end
+ end, list
+end
+
+local function keywordSource(vm, word, pos)
+ if not KEYMAP[word] then
+ return nil
+ end
+ return vm:instantSource {
+ type = 'keyword',
+ start = pos,
+ finish = pos + #word - 1,
+ [1] = word,
+ }
+end
+
+local function findStartPos(pos, buf)
+ local res = nil
+ for i = pos, 1, -1 do
+ local c = buf:sub(i, i)
+ if c:find '[%w_]' then
+ res = i
+ else
+ break
+ end
+ end
+ if not res then
+ for i = pos, 1, -1 do
+ local c = buf:sub(i, i)
+ if c == '.' or c == ':' or c == '|' or c == '(' then
+ res = i
+ break
+ elseif c == '#' or c == '@' then
+ res = i + 1
+ break
+ elseif c:find '[%s%c]' then
+ else
+ break
+ end
+ end
+ end
+ if not res then
+ return pos
+ end
+ return res
+end
+
+local function findWord(position, text)
+ local word = text
+ for i = position, 1, -1 do
+ local c = text:sub(i, i)
+ if not c:find '[%w_]' then
+ word = text:sub(i+1, position)
+ break
+ end
+ end
+ return word:match('^([%w_]*)')
+end
+
+local function getSource(vm, pos, text, filter)
+ local word = findWord(pos, text)
+ local source = keywordSource(vm, word, pos)
+ if source then
+ return source, pos, word
+ end
+ source = findSource(vm, pos, filter)
+ if source then
+ return source, pos, word
+ end
+ pos = findStartPos(pos, text)
+ source = findSource(vm, pos, filter) or findSource(vm, pos-1, filter)
+ return source, pos, word
+end
+
+return function (vm, text, pos, oldText)
+ local filter = {
+ ['name'] = true,
+ ['string'] = true,
+ ['.'] = true,
+ [':'] = true,
+ ['emmyName'] = true,
+ ['emmyIncomplete'] = true,
+ ['call'] = true,
+ ['function'] = true,
+ ['localfunction'] = true,
+ }
+ local source, pos, word = getSource(vm, pos, text, filter)
+ if not source then
+ return nil
+ end
+ if oldText then
+ local oldWord = oldText:sub(source.start, source.finish)
+ if word:sub(1, #oldWord) ~= oldWord then
+ return nil
+ end
+ end
+ State = {}
+ local callback, list = makeList(source, pos, word)
+ searchSpecial(vm, source, word, callback, pos, text)
+ searchCallArg(vm, source, word, callback, pos)
+ searchSource(vm, source, word, callback, pos)
+ if not oldText or #list > 0 then
+ if not State.ignoreText then
+ searchAllWords(vm, source, word, callback, pos)
+ end
+ end
+
+ if #list == 0 then
+ return nil
+ end
+
+ return list
+end
diff --git a/script/core/definition.lua b/script/core/definition.lua
new file mode 100644
index 00000000..8680a29b
--- /dev/null
+++ b/script/core/definition.lua
@@ -0,0 +1,296 @@
+local findSource = require 'core.find_source'
+local Mode
+
+local function parseValueSimily(callback, vm, source)
+ local key = source[1]
+ if not key then
+ return nil
+ end
+ vm:eachSource(function (other)
+ if other == source then
+ goto CONTINUE
+ end
+ if other[1] == key
+ and not other:bindLocal()
+ and other:bindValue()
+ and source:bindValue() ~= other:bindValue()
+ then
+ if Mode == 'definition' then
+ if other:action() == 'set' then
+ callback(other)
+ end
+ elseif Mode == 'reference' then
+ if other:action() == 'set' or other:action() == 'get' then
+ callback(other)
+ end
+ end
+ end
+ :: CONTINUE ::
+ end)
+end
+
+local function parseLocal(callback, vm, source)
+ ---@type Local
+ local loc = source:bindLocal()
+ callback(loc:getSource())
+ loc:eachInfo(function (info, src)
+ if Mode == 'definition' then
+ if info.type == 'set' or info.type == 'local' then
+ if vm.uri == src:getUri() then
+ if source.id >= src.id then
+ callback(src)
+ end
+ end
+ end
+ elseif Mode == 'reference' then
+ if info.type == 'set' or info.type == 'local' or info.type == 'return' or info.type == 'get' then
+ callback(src)
+ end
+ end
+ end)
+end
+
+local function parseValueByValue(callback, vm, source, value)
+ if not source then
+ return
+ end
+ local mark = { [vm] = true }
+ local list = {}
+ for _ = 1, 5 do
+ value:eachInfo(function (info, src)
+ if Mode == 'definition' then
+ if info.type == 'local' then
+ if vm.uri == src:getUri() then
+ if source.id >= src.id then
+ callback(src)
+ end
+ end
+ end
+ if info.type == 'set' then
+ if vm.uri == src:getUri() then
+ if source.id >= src.id then
+ callback(src)
+ end
+ else
+ callback(src)
+ end
+ end
+ if info.type == 'return' then
+ if (src.type ~= 'simple' or src[#src].type == 'call')
+ and src.type ~= 'name'
+ then
+ callback(src)
+ end
+ if vm.lsp then
+ local destVM = vm.lsp:getVM(src:getUri())
+ if destVM and not mark[destVM] then
+ mark[destVM] = true
+ list[#list+1] = { destVM, src }
+ end
+ end
+ end
+ elseif Mode == 'reference' then
+ if info.type == 'set' or info.type == 'local' or info.type == 'return' or info.type == 'get' then
+ callback(src)
+ end
+ end
+ end)
+ local nextData = table.remove(list, 1)
+ if nextData then
+ vm, source = nextData[1], nextData[2]
+ end
+ end
+end
+
+local function parseValue(callback, vm, source)
+ local value = source:bindValue()
+ local isGlobal
+ if value then
+ isGlobal = value:isGlobal()
+ parseValueByValue(callback, vm, source, value)
+ local emmy = value:getEmmy()
+ if emmy and emmy.type == 'emmy.type' then
+ ---@type EmmyType
+ local emmyType = emmy
+ emmyType:eachClass(function (class)
+ if class and class:getValue() then
+ local emmyVM = vm
+ if vm.lsp then
+ local destVM = vm.lsp:getVM(class:getSource():getUri())
+ if destVM then
+ emmyVM = destVM
+ end
+ end
+ parseValueByValue(callback, emmyVM, class:getValue():getSource(), class:getValue())
+ end
+ end)
+ end
+ end
+ local parent = source:get 'parent'
+ for _ = 1, 3 do
+ if parent then
+ local ok = parent:eachInfo(function (info, src)
+ if Mode == 'definition' then
+ if info.type == 'set child' and info[1] == source[1] then
+ callback(src)
+ return true
+ end
+ elseif Mode == 'reference' then
+ if (info.type == 'set child' or info.type == 'get child') and info[1] == source[1] then
+ callback(src)
+ return true
+ end
+ end
+ end)
+ if ok then
+ break
+ end
+ parent = parent:getMetaMethod('__index')
+ end
+ end
+ return isGlobal
+end
+
+local function parseLabel(callback, vm, label)
+ label:eachInfo(function (info, src)
+ if Mode == 'definition' then
+ if info.type == 'set' then
+ callback(src)
+ end
+ elseif Mode == 'reference' then
+ if info.type == 'set' or info.type == 'get' then
+ callback(src)
+ end
+ end
+ end)
+end
+
+local function jumpUri(callback, vm, source)
+ local uri = source:get 'target uri'
+ callback {
+ start = 0,
+ finish = 0,
+ uri = uri
+ }
+end
+
+local function parseClass(callback, vm, source)
+ local className = source:get 'emmy class'
+ vm.emmyMgr:eachClass(className, function (class)
+ if Mode == 'definition' then
+ if class.type == 'emmy.class' or class.type == 'emmy.alias' then
+ local src = class:getSource()
+ callback(src)
+ end
+ elseif Mode == 'reference' then
+ if class.type == 'emmy.class' or class.type == 'emmy.alias' or class.type == 'emmy.typeUnit' then
+ local src = class:getSource()
+ callback(src)
+ end
+ end
+ end)
+end
+
+local function parseSee(callback, vm, source)
+ local see = source:get 'emmy see'
+ local className = see[1][1]
+ local childName = see[2][1]
+ vm.emmyMgr:eachClass(className, function (class)
+ ---@type value
+ local value = class:getValue()
+ local child = value:getChild(childName)
+ parseValueByValue(callback, vm, source, child)
+ end)
+end
+
+local function parseFunction(callback, vm, source)
+ if Mode == 'definition' then
+ callback(source:bindFunction():getSource())
+ source:bindFunction():eachInfo(function (info, src)
+ if info.type == 'set' or info.type == 'local' then
+ if vm.uri == src:getUri() then
+ if source.id >= src.id then
+ callback(src)
+ end
+ else
+ callback(src)
+ end
+ end
+ end)
+ elseif Mode == 'reference' then
+ callback(source:bindFunction():getSource())
+ source:bindFunction():eachInfo(function (info, src)
+ if info.type == 'set' or info.type == 'local' or info.type == 'get' then
+ callback(src)
+ end
+ end)
+ end
+end
+
+local function makeList(source)
+ local list = {}
+ local mark = {}
+ return list, function (src)
+ if mark[src] then
+ return
+ end
+ mark[src] = true
+ list[#list+1] = {
+ src.start,
+ src.finish,
+ src.uri
+ }
+ end
+end
+
+return function (vm, pos, mode)
+ local filter = {
+ ['name'] = true,
+ ['string'] = true,
+ ['number'] = true,
+ ['boolean'] = true,
+ ['label'] = true,
+ ['goto'] = true,
+ ['function'] = true,
+ ['...'] = true,
+ ['emmyName'] = true,
+ ['emmyIncomplete'] = true,
+ }
+ local source = findSource(vm, pos, filter)
+ if not source then
+ return nil
+ end
+ Mode = mode
+ local list, callback = makeList(source)
+ local isGlobal
+ if source:bindLocal() then
+ parseLocal(callback, vm, source)
+ end
+ if source:bindValue() then
+ isGlobal = parseValue(callback, vm, source)
+ end
+ if source:bindLabel() then
+ parseLabel(callback, vm, source:bindLabel())
+ end
+ if source:bindFunction() then
+ parseFunction(callback, vm, source)
+ end
+ if source:get 'target uri' then
+ jumpUri(callback, vm, source)
+ end
+ if source:get 'in index' then
+ isGlobal = parseValue(callback, vm, source)
+ end
+ if source:get 'emmy class' then
+ parseClass(callback, vm, source)
+ end
+ if source:get 'emmy see' then
+ parseSee(callback, vm, source)
+ end
+
+ if #list == 0 then
+ parseValueSimily(callback, vm, source)
+ end
+
+ return list, isGlobal
+end
diff --git a/script/core/diagnostics.lua b/script/core/diagnostics.lua
new file mode 100644
index 00000000..3b11b818
--- /dev/null
+++ b/script/core/diagnostics.lua
@@ -0,0 +1,1042 @@
+local lang = require 'language'
+local config = require 'config'
+local library = require 'core.library'
+local buildGlobal = require 'vm.global'
+local DiagnosticSeverity = require 'constant.DiagnosticSeverity'
+local DiagnosticDefaultSeverity = require 'constant.DiagnosticDefaultSeverity'
+local DiagnosticTag = require 'constant.DiagnosticTag'
+
+local mt = {}
+mt.__index = mt
+
+local function isContainPos(obj, start, finish)
+ if obj.start <= start and obj.finish >= finish then
+ return true
+ end
+ return false
+end
+
+function mt:searchUnusedLocals(callback)
+ self.vm:eachSource(function (source)
+ local loc = source:bindLocal()
+ if not loc then
+ return
+ end
+ if loc:get 'emmy arg' then
+ return
+ end
+ local name = loc:getName()
+ if name == '_' or name == '_ENV' or name == '' then
+ return
+ end
+ if source:action() ~= 'local' then
+ return
+ end
+ if loc:get 'hide' then
+ return
+ end
+ local used = loc:eachInfo(function (info)
+ if info.type == 'get' then
+ return true
+ end
+ end)
+ if not used then
+ callback(source.start, source.finish, name)
+ end
+ end)
+end
+
+function mt:searchUnusedFunctions(callback)
+ self.vm:eachSource(function (source)
+ local loc = source:bindLocal()
+ if not loc then
+ return
+ end
+ if loc:get 'emmy arg' then
+ return
+ end
+ if source:action() ~= 'local' then
+ return
+ end
+ if loc:get 'hide' then
+ return
+ end
+ local used = loc:eachInfo(function (info)
+ if info.type == 'get' then
+ return true
+ end
+ end)
+ if used then
+ return
+ end
+ loc:eachInfo(function (info, src)
+ if info.type == 'set' or info.type == 'local' then
+ local v = src:bindValue()
+ local func = v and v:getFunction()
+ if func and func:getSource().uri == self.vm.uri then
+ callback(func:getSource().start, func:getSource().finish)
+ end
+ end
+ end)
+ end)
+end
+
+function mt:searchUndefinedGlobal(callback)
+ local definedGlobal = {}
+ for name in pairs(config.config.diagnostics.globals) do
+ definedGlobal[name] = true
+ end
+ local envValue = buildGlobal(self.vm.lsp)
+ envValue:eachInfo(function (info)
+ if info.type == 'set child' then
+ local name = info[1]
+ definedGlobal[name] = true
+ end
+ end)
+ self.vm:eachSource(function (source)
+ if not source:get 'global' then
+ return
+ end
+ local name = source:getName()
+ if name == '' then
+ return
+ end
+ local parent = source:get 'parent'
+ if not parent then
+ return
+ end
+ if not parent:get 'ENV' and not source:get 'in index' then
+ return
+ end
+ if definedGlobal[name] then
+ return
+ end
+ if type(name) ~= 'string' then
+ return
+ end
+ callback(source.start, source.finish, name)
+ end)
+end
+
+function mt:searchUnusedLabel(callback)
+ self.vm:eachSource(function (source)
+ local label = source:bindLabel()
+ if not label then
+ return
+ end
+ if source:action() ~= 'set' then
+ return
+ end
+ local used = label:eachInfo(function (info)
+ if info.type == 'get' then
+ return true
+ end
+ end)
+ if not used then
+ callback(source.start, source.finish, label:getName())
+ end
+ end)
+end
+
+function mt:searchUnusedVararg(callback)
+ self.vm:eachSource(function (source)
+ local value = source:bindFunction()
+ if not value then
+ return
+ end
+ local func = value:getFunction()
+ if not func then
+ return
+ end
+ if func._dotsSource and not func._dotsLoad then
+ callback(func._dotsSource.start, func._dotsSource.finish)
+ end
+ end)
+end
+
+local function isInString(vm, start, finish)
+ return vm:eachSource(function (source)
+ if source.type == 'string' and isContainPos(source, start, finish) then
+ return true
+ end
+ end)
+end
+
+function mt:searchSpaces(callback)
+ local vm = self.vm
+ local lines = self.lines
+ for i = 1, #lines do
+ local line = lines:line(i)
+
+ if line:find '^[ \t]+$' then
+ local start, finish = lines:range(i)
+ if isInString(vm, start, finish) then
+ goto NEXT_LINE
+ end
+ callback(start, finish, lang.script.DIAG_LINE_ONLY_SPACE)
+ goto NEXT_LINE
+ end
+
+ local pos = line:find '[ \t]+$'
+ if pos then
+ local start, finish = lines:range(i)
+ start = start + pos - 1
+ if isInString(vm, start, finish) then
+ goto NEXT_LINE
+ end
+ callback(start, finish, lang.script.DIAG_LINE_POST_SPACE)
+ goto NEXT_LINE
+ end
+
+ ::NEXT_LINE::
+ end
+end
+
+function mt:searchRedefinition(callback)
+ local used = {}
+ local uri = self.uri
+ self.vm:eachSource(function (source)
+ local loc = source:bindLocal()
+ if not loc then
+ return
+ end
+ local shadow = loc:shadow()
+ if not shadow then
+ return
+ end
+ if used[shadow] then
+ return
+ end
+ used[shadow] = true
+ if loc:get 'hide' then
+ return
+ end
+ local name = loc:getName()
+ if name == '_' or name == '_ENV' or name == '' then
+ return
+ end
+ local related = {}
+ for i = 1, #shadow do
+ related[i] = {
+ start = shadow[i]:getSource().start,
+ finish = shadow[i]:getSource().finish,
+ uri = uri,
+ }
+ end
+ for i = 2, #shadow do
+ callback(shadow[i]:getSource().start, shadow[i]:getSource().finish, name, related)
+ end
+ end)
+end
+
+function mt:searchNewLineCall(callback)
+ local lines = self.lines
+ self.vm:eachSource(function (source)
+ if source.type ~= 'simple' then
+ return
+ end
+ for i = 1, #source - 1 do
+ local callSource = source[i]
+ local funcSource = source[i-1]
+ if callSource.type ~= 'call' then
+ goto CONTINUE
+ end
+ local callLine = lines:rowcol(callSource.start)
+ local funcLine = lines:rowcol(funcSource.finish)
+ if callLine > funcLine then
+ callback(callSource.start, callSource.finish)
+ end
+ :: CONTINUE ::
+ end
+ end)
+end
+
+function mt:searchNewFieldCall(callback)
+ local lines = self.lines
+ self.vm:eachSource(function (source)
+ if source.type ~= 'table' then
+ return
+ end
+ for i = 1, #source do
+ local field = source[i]
+ if field.type == 'simple' then
+ local callSource = field[#field]
+ local funcSource = field[#field-1]
+ local callLine = lines:rowcol(callSource.start)
+ local funcLine = lines:rowcol(funcSource.finish)
+ if callLine > funcLine then
+ callback(funcSource.start, callSource.finish
+ , lines.buf:sub(funcSource.start, funcSource.finish)
+ , lines.buf:sub(callSource.start, callSource.finish)
+ )
+ end
+ end
+ end
+ end)
+end
+
+function mt:searchRedundantParameters(callback)
+ self.vm:eachSource(function (source)
+ local args = source:bindCall()
+ if not args then
+ return
+ end
+
+ -- 回调函数不检查
+ local simple = source:get 'simple'
+ if simple and simple[2] == source then
+ local loc = simple[1]:bindLocal()
+ if loc then
+ local source = loc:getSource()
+ if source:get 'arg' then
+ return
+ end
+ end
+ end
+
+ local value = source:findCallFunction()
+ if not value then
+ return
+ end
+
+ local func = value:getFunction()
+ -- 参数中有 ... ,不用再检查了
+ if func:hasDots() then
+ return
+ end
+ local max = #func.args
+ local passed = #args
+ -- function m.open() end
+ -- m:open()
+ -- 这种写法不算错
+ if passed == 1 and source:get 'has object' then
+ return
+ end
+ for i = max + 1, passed do
+ local extra = args[i]
+ callback(extra.start, extra.finish, max, passed)
+ end
+ end)
+end
+
+local opMap = {
+ ['+'] = true,
+ ['-'] = true,
+ ['*'] = true,
+ ['/'] = true,
+ ['//'] = true,
+ ['^'] = true,
+ ['<<'] = true,
+ ['>>'] = true,
+ ['&'] = true,
+ ['|'] = true,
+ ['~'] = true,
+ ['..'] = true,
+}
+
+local literalMap = {
+ ['number'] = true,
+ ['boolean'] = true,
+ ['string'] = true,
+ ['table'] = true,
+}
+
+function mt:searchAmbiguity1(callback)
+ self.vm:eachSource(function (source)
+ if source.op ~= 'or' then
+ return
+ end
+ local first = source[1]
+ local second = source[2]
+ -- a + (b or 0) --> (a + b) or 0
+ do
+ if opMap[first.op]
+ and first.type ~= 'unary'
+ and not second.op
+ and literalMap[second.type]
+ and not first.brackets
+ then
+ callback(source.start, source.finish, first.start, first.finish)
+ end
+ end
+ -- (a or 0) + c --> a or (0 + c)
+ do
+ if opMap[second.op]
+ and second.type ~= 'unary'
+ and not first.op
+ and literalMap[second[1].type]
+ and not second.brackets
+ then
+ callback(source.start, source.finish, second.start, second.finish)
+ end
+ end
+ end)
+end
+
+function mt:searchLowercaseGlobal(callback)
+ local definedGlobal = {}
+ for name in pairs(config.config.diagnostics.globals) do
+ definedGlobal[name] = true
+ end
+ for name in pairs(library.global) do
+ definedGlobal[name] = true
+ end
+ self.vm:eachSource(function (source)
+ if source.type == 'name'
+ and source:get 'parent'
+ and not source:get 'simple'
+ and not source:get 'table index'
+ and source:action() == 'set'
+ then
+ local name = source[1]
+ if definedGlobal[name] then
+ return
+ end
+ local first = name:match '%w'
+ if not first then
+ return
+ end
+ if first:match '%l' then
+ callback(source.start, source.finish)
+ end
+ end
+ end)
+end
+
+function mt:searchDuplicateIndex(callback)
+ self.vm:eachSource(function (source)
+ if source.type ~= 'table' then
+ return
+ end
+ local mark = {}
+ for _, obj in ipairs(source) do
+ if obj.type == 'pair' then
+ local key = obj[1]
+ local name
+ if key.index then
+ if key.type == 'string' then
+ name = key[1]
+ end
+ elseif key.type == 'name' then
+ name = key[1]
+ end
+ if name then
+ if mark[name] then
+ mark[name][#mark[name]+1] = obj
+ else
+ mark[name] = { obj }
+ end
+ end
+ end
+ end
+ for name, defs in pairs(mark) do
+ if #defs > 1 then
+ local related = {}
+ for i = 1, #defs do
+ related[i] = {
+ start = defs[i][1].start,
+ finish = defs[i][2].finish,
+ uri = self.uri,
+ }
+ end
+ for i = 1, #defs - 1 do
+ callback(defs[i][1].start, defs[i][2].finish, name, related, 'unused')
+ end
+ for i = #defs, #defs do
+ callback(defs[i][1].start, defs[i][1].finish, name, related, 'duplicate')
+ end
+ end
+ end
+ end)
+end
+
+function mt:searchDuplicateMethod(callback)
+ local uri = self.uri
+ local mark = {}
+ local map = {}
+ self.vm:eachSource(function (source)
+ local parent = source:get 'parent'
+ if not parent then
+ return
+ end
+ if mark[parent] then
+ return
+ end
+ mark[parent] = true
+ local relates = {}
+ parent:eachInfo(function (info, src)
+ local k = info[1]
+ if info.type ~= 'set child' then
+ return
+ end
+ if type(k) ~= 'string' then
+ return
+ end
+ if src.start == 0 then
+ return
+ end
+ if not src:get 'object' then
+ return
+ end
+ if map[src] then
+ return
+ end
+ if not relates[k] then
+ relates[k] = map[src] or {
+ name = k,
+ }
+ end
+ map[src] = relates[k]
+ relates[k][#relates[k]+1] = {
+ start = src.start,
+ finish = src.finish,
+ uri = src.uri
+ }
+ end)
+ end)
+ for src, relate in pairs(map) do
+ if #relate > 1 and src.uri == uri then
+ callback(src.start, src.finish, relate.name, relate)
+ end
+ end
+end
+
+function mt:searchEmptyBlock(callback)
+ self.vm:eachSource(function (source)
+ -- 认为空repeat与空while是合法的
+ -- 要去vm中激活source
+ if source.type == 'if' then
+ for _, block in ipairs(source) do
+ if #block > 0 then
+ return
+ end
+ end
+ callback(source.start, source.finish)
+ return
+ end
+ if source.type == 'loop'
+ or source.type == 'in'
+ then
+ if #source == 0 then
+ callback(source.start, source.finish)
+ end
+ return
+ end
+ end)
+end
+
+function mt:searchRedundantValue(callback)
+ self.vm:eachSource(function (source)
+ if source.type == 'set' or source.type == 'local' then
+ local args = source[1]
+ local values = source[2]
+ if not source[2] then
+ return
+ end
+ local argCount, valueCount
+ if args.type == 'list' then
+ argCount = #args
+ else
+ argCount = 1
+ end
+ if values.type == 'list' then
+ valueCount = #values
+ else
+ valueCount = 1
+ end
+ for i = argCount + 1, valueCount do
+ local value = values[i]
+ callback(value.start, value.finish, argCount, valueCount)
+ end
+ end
+ end)
+end
+
+function mt:searchUndefinedEnvChild(callback)
+ self.vm:eachSource(function (source)
+ if not source:get 'global' then
+ return
+ end
+ local name = source:getName()
+ if name == '' then
+ return
+ end
+ if source:get 'in index' then
+ return
+ end
+ local parent = source:get 'parent'
+ if parent:get 'ENV' then
+ return
+ end
+ local value = source:bindValue()
+ if not value then
+ return
+ end
+ if value:getSource() == source then
+ callback(source.start, source.finish, name)
+ end
+ return
+ end)
+end
+
+function mt:searchGlobalInNilEnv(callback)
+ self.vm:eachSource(function (source)
+ if not source:get 'global' then
+ return
+ end
+ local name = source:getName()
+ if name == '' then
+ return
+ end
+ local parentSource = source:get 'parent' :getSource()
+ if parentSource and parentSource.type == 'nil' then
+ callback(source.start, source.finish, {
+ {
+ start = parentSource.start,
+ finish = parentSource.finish,
+ uri = self.uri,
+ }
+ })
+ end
+ return
+ end)
+end
+
+function mt:checkEmmyClass(source, callback)
+ local class = source:get 'emmy.class'
+ if not class then
+ return
+ end
+ -- class重复定义
+ local name = class:getName()
+ local related = {}
+ self.vm.emmyMgr:eachClass(name, function (class)
+ if class.type ~= 'emmy.class' and class.type ~= 'emmy.alias' then
+ return
+ end
+ local src = class:getSource()
+ if src ~= source then
+ related[#related+1] = {
+ start = src.start,
+ finish = src.finish,
+ uri = src.uri,
+ }
+ end
+ end)
+ if #related > 0 then
+ callback(source.start, source.finish, lang.script.DIAG_DUPLICATE_CLASS ,related)
+ end
+ -- 继承不存在的class
+ local extends = class.extends
+ if not extends then
+ return
+ end
+ local parent = self.vm.emmyMgr:eachClass(extends, function (parent)
+ if parent.type == 'emmy.class' then
+ return parent
+ end
+ end)
+ if not parent then
+ callback(source[2].start, source[2].finish, lang.script.DIAG_UNDEFINED_CLASS)
+ return
+ end
+
+ -- class循环继承
+ local related = {}
+ local current = class
+ for _ = 1, 10 do
+ local extends = current.extends
+ if not extends then
+ break
+ end
+ related[#related+1] = {
+ start = current:getSource().start,
+ finish = current:getSource().finish,
+ uri = current:getSource().uri,
+ }
+ current = self.vm.emmyMgr:eachClass(extends, function (parent)
+ if parent.type == 'emmy.class' then
+ return parent
+ end
+ end)
+ if not current then
+ break
+ end
+ if current:getName() == class:getName() then
+ callback(source.start, source.finish, lang.script.DIAG_CYCLIC_EXTENDS, related)
+ break
+ end
+ end
+end
+
+function mt:checkEmmyType(source, callback)
+ for _, tpsource in ipairs(source) do
+ local name = tpsource[1]
+ local class = self.vm.emmyMgr:eachClass(name, function (class)
+ if class.type == 'emmy.class' or class.type == 'emmy.alias' then
+ return class
+ end
+ end)
+ if not class then
+ callback(tpsource.start, tpsource.finish, lang.script.DIAG_UNDEFINED_CLASS)
+ end
+ end
+end
+
+function mt:checkEmmyAlias(source, callback)
+ local class = source:get 'emmy.alias'
+ if not class then
+ return
+ end
+ -- class重复定义
+ local name = class:getName()
+ local related = {}
+ self.vm.emmyMgr:eachClass(name, function (class)
+ if class.type ~= 'emmy.class' and class.type ~= 'emmy.alias' then
+ return
+ end
+ local src = class:getSource()
+ if src ~= source then
+ related[#related+1] = {
+ start = src.start,
+ finish = src.finish,
+ uri = src.uri,
+ }
+ end
+ end)
+ if #related > 0 then
+ callback(source.start, source.finish, lang.script.DIAG_DUPLICATE_CLASS ,related)
+ end
+end
+
+function mt:checkEmmyParam(source, callback, mark)
+ local func = source:get 'emmy function'
+ if not func then
+ return
+ end
+ if mark[func] then
+ return
+ end
+ mark[func] = true
+
+ -- 检查不存在的参数
+ local emmyParams = func:getEmmyParams()
+ local funcParams = {}
+ if func.args then
+ for _, arg in ipairs(func.args) do
+ funcParams[arg.name] = true
+ end
+ end
+ for _, param in ipairs(emmyParams) do
+ local name = param:getName()
+ if not funcParams[name] then
+ callback(param:getSource()[1].start, param:getSource()[1].finish, lang.script.DIAG_INEXISTENT_PARAM)
+ end
+ end
+
+ -- 检查重复的param
+ local lists = {}
+ for _, param in ipairs(emmyParams) do
+ local name = param:getName()
+ if not lists[name] then
+ lists[name] = {}
+ end
+ lists[name][#lists[name]+1] = param:getSource()[1]
+ end
+ for _, list in pairs(lists) do
+ if #list > 1 then
+ local related = {}
+ for _, src in ipairs(list) do
+ related[#related+1] = {
+ src.start,
+ src.finish,
+ src.uri,
+ }
+ callback(src.start, src.finish, lang.script.DIAG_DUPLICATE_PARAM)
+ end
+ end
+ end
+end
+
+function mt:checkEmmyField(source, callback, mark)
+ ---@type EmmyClass
+ local class = source:get 'target class'
+ -- 必须写在 class 的后面
+ if not class then
+ callback(source.start, source.finish, lang.script.DIAG_NEED_CLASS)
+ end
+
+ -- 检查重复的 field
+ if class and not mark[class] then
+ mark[class] = true
+ local lists = {}
+ class:eachField(function (field)
+ local name = field:getName()
+ if not lists[name] then
+ lists[name] = {}
+ end
+ lists[name][#lists[name]+1] = field:getSource()[2]
+ end)
+ for _, list in pairs(lists) do
+ if #list > 1 then
+ local related = {}
+ for _, src in ipairs(list) do
+ related[#related+1] = {
+ src.start,
+ src.finish,
+ src.uri,
+ }
+ callback(src.start, src.finish, lang.script.DIAG_DUPLICATE_FIELD)
+ end
+ end
+ end
+ end
+end
+
+function mt:searchEmmyLua(callback)
+ local mark = {}
+ self.vm:eachSource(function (source)
+ if source.type == 'emmyClass' then
+ self:checkEmmyClass(source, callback)
+ elseif source.type == 'emmyType' then
+ self:checkEmmyType(source, callback)
+ elseif source.type == 'emmyAlias' then
+ self:checkEmmyAlias(source, callback)
+ elseif source.type == 'emmyParam' then
+ self:checkEmmyParam(source, callback, mark)
+ elseif source.type == 'emmyField' then
+ self:checkEmmyField(source, callback, mark)
+ end
+ end)
+end
+
+function mt:searchSetConstLocal(callback)
+ local mark = {}
+ self.vm:eachSource(function (source)
+ local loc = source:bindLocal()
+ if not loc then
+ return
+ end
+ if mark[loc] then
+ return
+ end
+ mark[loc] = true
+ if not loc.tags then
+ return
+ end
+ local const
+ for _, tag in ipairs(loc.tags) do
+ if tag[1] == 'const' then
+ const = true
+ break
+ end
+ end
+ if not const then
+ return
+ end
+ loc:eachInfo(function (info, src)
+ if info.type == 'set' then
+ callback(src.start, src.finish)
+ end
+ end)
+ end)
+end
+
+function mt:doDiagnostics(func, code, callback)
+ if config.config.diagnostics.disable[code] then
+ return
+ end
+ local level = config.config.diagnostics.severity[code]
+ if not DiagnosticSeverity[level] then
+ level = DiagnosticDefaultSeverity[code]
+ end
+ func(self, function (start, finish, ...)
+ local data = callback(...)
+ data.code = code
+ data.start = start
+ data.finish = finish
+ data.level = data.level or DiagnosticSeverity[level]
+ self.datas[#self.datas+1] = data
+ end)
+ if coroutine.isyieldable() then
+ if self.vm:isRemoved() then
+ coroutine.yield('stop')
+ else
+ coroutine.yield()
+ end
+ end
+end
+
+return function (vm, lines, uri)
+ local session = setmetatable({
+ vm = vm,
+ lines = lines,
+ uri = uri,
+ datas = {},
+ }, mt)
+
+ -- 未使用的局部变量
+ session:doDiagnostics(session.searchUnusedLocals, 'unused-local', function (key)
+ return {
+ message = lang.script('DIAG_UNUSED_LOCAL', key),
+ tags = {DiagnosticTag.Unnecessary},
+ }
+ end)
+ -- 未使用的函数
+ session:doDiagnostics(session.searchUnusedFunctions, 'unused-function', function ()
+ return {
+ message = lang.script.DIAG_UNUSED_FUNCTION,
+ tags = {DiagnosticTag.Unnecessary},
+ }
+ end)
+ -- 读取未定义全局变量
+ session:doDiagnostics(session.searchUndefinedGlobal, 'undefined-global', function (key)
+ local message = lang.script('DIAG_UNDEF_GLOBAL', key)
+ local otherVersion = library.other[key]
+ local customLib = library.custom[key]
+ if otherVersion then
+ message = ('%s(%s)'):format(message, lang.script('DIAG_DEFINED_VERSION', table.concat(otherVersion, '/'), config.config.runtime.version))
+ end
+ if customLib then
+ message = ('%s(%s)'):format(message, lang.script('DIAG_DEFINED_CUSTOM', table.concat(customLib, '/')))
+ end
+ return {
+ message = message,
+ }
+ end)
+ -- 未使用的Label
+ session:doDiagnostics(session.searchUnusedLabel, 'unused-label', function (key)
+ return {
+ message = lang.script('DIAG_UNUSED_LABEL', key),
+ tags = {DiagnosticTag.Unnecessary},
+ }
+ end)
+ -- 未使用的不定参数
+ session:doDiagnostics(session.searchUnusedVararg, 'unused-vararg', function ()
+ return {
+ message = lang.script.DIAG_UNUSED_VARARG,
+ tags = {DiagnosticTag.Unnecessary},
+ }
+ end)
+ -- 只有空格与制表符的行,以及后置空格
+ session:doDiagnostics(session.searchSpaces, 'trailing-space', function (message)
+ return {
+ message = message,
+ }
+ end)
+ -- 重定义局部变量
+ session:doDiagnostics(session.searchRedefinition, 'redefined-local', function (key, related)
+ return {
+ message = lang.script('DIAG_REDEFINED_LOCAL', key),
+ related = related,
+ }
+ end)
+ -- 以括号开始的一行(可能被误解析为了上一行的call)
+ session:doDiagnostics(session.searchNewLineCall, 'newline-call', function ()
+ return {
+ message = lang.script.DIAG_PREVIOUS_CALL,
+ }
+ end)
+ -- 以字符串开始的field(可能被误解析为了上一行的call)
+ session:doDiagnostics(session.searchNewFieldCall, 'newfield-call', function (func, call)
+ return {
+ message = lang.script('DIAG_PREFIELD_CALL', func, call),
+ }
+ end)
+ -- 调用函数时的参数数量是否超过函数的接收数量
+ session:doDiagnostics(session.searchRedundantParameters, 'redundant-parameter', function (max, passed)
+ return {
+ message = lang.script('DIAG_OVER_MAX_ARGS', max, passed),
+ tags = {DiagnosticTag.Unnecessary},
+ }
+ end)
+ -- x or 0 + 1
+ session:doDiagnostics(session.searchAmbiguity1, 'ambiguity-1', function (start, finish)
+ return {
+ message = lang.script('DIAG_AMBIGUITY_1', lines.buf:sub(start, finish)),
+ }
+ end)
+ -- 不允许定义首字母小写的全局变量(很可能是拼错或者漏删)
+ session:doDiagnostics(session.searchLowercaseGlobal, 'lowercase-global', function ()
+ return {
+ message = lang.script.DIAG_LOWERCASE_GLOBAL,
+ }
+ end)
+ -- 未定义的变量(重载了 `_ENV`)
+ session:doDiagnostics(session.searchUndefinedEnvChild, 'undefined-env-child', function (key)
+ if vm.envType == '_ENV' then
+ return {
+ message = lang.script('DIAG_UNDEF_ENV_CHILD', key),
+ }
+ else
+ return {
+ message = lang.script('DIAG_UNDEF_FENV_CHILD', key),
+ }
+ end
+ end)
+ -- 全局变量不可用(置空了 `_ENV`)
+ session:doDiagnostics(session.searchGlobalInNilEnv, 'global-in-nil-env', function (related)
+ if vm.envType == '_ENV' then
+ return {
+ message = lang.script.DIAG_GLOBAL_IN_NIL_ENV,
+ related = related,
+ }
+ else
+ return {
+ message = lang.script.DIAG_GLOBAL_IN_NIL_FENV,
+ related = related,
+ }
+ end
+ end)
+ -- 构建表时重复定义field
+ session:doDiagnostics(session.searchDuplicateIndex, 'duplicate-index', function (key, related, type)
+ if type == 'unused' then
+ return {
+ message = lang.script('DIAG_DUPLICATE_INDEX', key),
+ related = related,
+ level = DiagnosticSeverity.Hint,
+ tags = {DiagnosticTag.Unnecessary},
+ }
+ else
+ return {
+ message = lang.script('DIAG_DUPLICATE_INDEX', key),
+ related = related,
+ }
+ end
+ end)
+ -- 往表里面塞重复的method
+ --session:doDiagnostics(session.searchDuplicateMethod, 'duplicate-method', function (key, related)
+ -- return {
+ -- message = lang.script('DIAG_DUPLICATE_METHOD', key),
+ -- related = related,
+ -- }
+ --end)
+ -- 空代码块
+ session:doDiagnostics(session.searchEmptyBlock, 'empty-block', function ()
+ return {
+ message = lang.script.DIAG_EMPTY_BLOCK,
+ tags = {DiagnosticTag.Unnecessary},
+ }
+ end)
+ -- 多余的赋值
+ session:doDiagnostics(session.searchRedundantValue, 'redundant-value', function (max, passed)
+ return {
+ message = lang.script('DIAG_OVER_MAX_VALUES', max, passed),
+ tags = {DiagnosticTag.Unnecessary},
+ }
+ end)
+ -- Emmy相关的检查
+ session:doDiagnostics(session.searchEmmyLua, 'emmy-lua', function (message, related)
+ return {
+ message = message,
+ related = related,
+ }
+ end)
+ -- 检查给const变量赋值
+ session:doDiagnostics(session.searchSetConstLocal, 'set-const', function ()
+ return {
+ message = lang.script.DIAG_SET_CONST
+ }
+ end)
+ return session.datas
+end
diff --git a/script/core/document_symbol.lua b/script/core/document_symbol.lua
new file mode 100644
index 00000000..48e01332
--- /dev/null
+++ b/script/core/document_symbol.lua
@@ -0,0 +1,260 @@
+local hoverFunction = require 'core.hover.function'
+local getName = require 'core.name'
+local hover = require 'core.hover'
+
+local SymbolKind = {
+ File = 1,
+ Module = 2,
+ Namespace = 3,
+ Package = 4,
+ Class = 5,
+ Method = 6,
+ Property = 7,
+ Field = 8,
+ Constructor = 9,
+ Enum = 10,
+ Interface = 11,
+ Function = 12,
+ Variable = 13,
+ Constant = 14,
+ String = 15,
+ Number = 16,
+ Boolean = 17,
+ Array = 18,
+ Object = 19,
+ Key = 20,
+ Null = 21,
+ EnumMember = 22,
+ Struct = 23,
+ Event = 24,
+ Operator = 25,
+ TypeParameter = 26,
+}
+
+local function buildLocal(vm, source, used, callback)
+ local vars = source[1]
+ local exps = source[2]
+ if vars.type ~= 'list' then
+ vars = {vars}
+ end
+ if not exps or exps.type ~= 'list' then
+ exps = {exps}
+ end
+ for i, var in ipairs(vars) do
+ local exp = exps[i]
+ local data = {}
+ local loc = var:bindLocal()
+ data.name = loc:getName()
+ data.range = { var.start, var.finish }
+ data.selectionRange = { var.start, var.finish }
+ if exp then
+ local hvr = hover(var)
+ if exp.type == 'function' then
+ data.kind = SymbolKind.Function
+ else
+ data.kind = SymbolKind.Variable
+ end
+ data.detail = hvr.label:gsub('[\r\n]', '')
+ data.valueRange = { exp.start, exp.finish }
+ used[exp] = true
+ else
+ data.kind = SymbolKind.Variable
+ data.detail = ''
+ data.valueRange = { var.start, var.finish }
+ end
+ callback(data)
+ end
+end
+
+local function buildSet(vm, source, used, callback)
+ local vars = source[1]
+ local exps = source[2]
+ if vars.type ~= 'list' then
+ vars = {vars}
+ end
+ if not exps or exps.type ~= 'list' then
+ exps = {exps}
+ end
+ for i, var in ipairs(vars) do
+ if var:bindLocal() then
+ goto CONTINUE
+ end
+ local exp = exps[i]
+ local data = {}
+ data.name = getName(var)
+ data.range = { var.start, var.finish }
+ data.selectionRange = { var.start, var.finish }
+ if exp then
+ local hvr = hover(var)
+ if not hvr then
+ goto CONTINUE
+ end
+ if exp.type == 'function' then
+ data.kind = SymbolKind.Function
+ else
+ data.kind = SymbolKind.Property
+ end
+ data.detail = hvr.label:gsub('[\r\n]', '')
+ data.valueRange = { exp.start, exp.finish }
+ used[exp] = true
+ else
+ data.kind = SymbolKind.Property
+ data.detail = ''
+ data.valueRange = { var.start, var.finish }
+ end
+ callback(data)
+ :: CONTINUE ::
+ end
+end
+
+local function buildPair(vm, source, used, callback)
+ local var = source[1]
+ local exp = source[2]
+ local data = {}
+ data.name = getName(var)
+ data.range = { var.start, var.finish }
+ data.selectionRange = { var.start, var.finish }
+ if exp then
+ local hvr = hover(var)
+ if not hvr then
+ return
+ end
+ if exp.type == 'function' then
+ data.kind = SymbolKind.Function
+ else
+ data.kind = SymbolKind.Class
+ end
+ data.detail = hvr.label:gsub('[\r\n]', '')
+ data.valueRange = { exp.start, exp.finish }
+ used[exp] = true
+ else
+ data.kind = SymbolKind.Class
+ data.detail = ''
+ data.valueRange = { var.start, var.finish }
+ end
+ callback(data)
+end
+
+local function buildLocalFunction(vm, source, used, callback)
+ local value = source:bindFunction()
+ if not value then
+ return
+ end
+ local name = getName(source.name)
+ local hvr = hoverFunction(name, value:getFunction())
+ if not hvr then
+ return
+ end
+ local kind = SymbolKind.Function
+ callback {
+ name = name,
+ detail = hvr.label:gsub('[\r\n]', ''),
+ kind = kind,
+ range = { source.start, source.finish },
+ selectionRange = { source.name.start, source.name.finish },
+ valueRange = { source.start, source.finish },
+ }
+end
+
+
+local function buildFunction(vm, source, used, callback)
+ if used[source] then
+ return
+ end
+ local value = source:bindFunction()
+ if not value then
+ return
+ end
+ local name = getName(source.name)
+ local func = value:getFunction()
+ if not func then
+ return
+ end
+ local hvr = hoverFunction(name, func, func:getObject())
+ if not hvr then
+ return
+ end
+ local data = {}
+ data.name = name
+ data.detail = hvr.label:gsub('[\r\n]', '')
+ data.range = { source.start, source.finish }
+ data.valueRange = { source.start, source.finish }
+ if source.name then
+ data.selectionRange = { source.name.start, source.name.finish }
+ else
+ data.selectionRange = { source.start, source.start }
+ end
+ if func:getObject() then
+ data.kind = SymbolKind.Field
+ else
+ data.kind = SymbolKind.Function
+ end
+ callback(data)
+end
+
+local function buildSource(vm, source, used, callback)
+ if source.type == 'local' then
+ buildLocal(vm, source, used, callback)
+ return
+ end
+ if source.type == 'set' then
+ buildSet(vm, source, used, callback)
+ return
+ end
+ if source.type == 'pair' then
+ buildPair(vm, source, used, callback)
+ return
+ end
+ if source.type == 'localfunction' then
+ buildLocalFunction(vm, source, used, callback)
+ return
+ end
+ if source.type == 'function' then
+ buildFunction(vm, source, used, callback)
+ return
+ end
+end
+
+local function packChild(symbols, finish, kind)
+ local t
+ while true do
+ local symbol = symbols[#symbols]
+ if not symbol then
+ break
+ end
+ if symbol.valueRange[1] > finish then
+ break
+ end
+ symbols[#symbols] = nil
+ symbol.children = packChild(symbols, symbol.valueRange[2], symbol.kind)
+ if not t then
+ t = {}
+ end
+ t[#t+1] = symbol
+ end
+ return t
+end
+
+local function packSymbols(symbols)
+ -- 按照start位置反向排序
+ table.sort(symbols, function (a, b)
+ return a.range[1] > b.range[1]
+ end)
+ -- 处理嵌套
+ return packChild(symbols, math.maxinteger, SymbolKind.Function)
+end
+
+return function (vm)
+ local symbols = {}
+ local used = {}
+
+ vm:eachSource(function (source)
+ buildSource(vm, source, used, function (data)
+ symbols[#symbols+1] = data
+ end)
+ end)
+
+ local packedSymbols = packSymbols(symbols)
+
+ return packedSymbols
+end
diff --git a/script/core/find_lib.lua b/script/core/find_lib.lua
new file mode 100644
index 00000000..e76549a8
--- /dev/null
+++ b/script/core/find_lib.lua
@@ -0,0 +1,65 @@
+local hoverName = require 'core.hover.name'
+
+local function getParentName(lib, isObject)
+ for _, parent in ipairs(lib.parent) do
+ if isObject then
+ if parent.type == 'object' then
+ return parent.nick or parent.name
+ end
+ else
+ if parent.type ~= 'object' then
+ return parent.nick or parent.name
+ end
+ end
+ end
+ return ''
+end
+
+local function findLib(source)
+ local value = source:bindValue()
+ local lib = value:getLib()
+ if not lib then
+ return nil
+ end
+ if lib.parent then
+ if source:get 'object' then
+ -- *string:sub
+ local fullKey = ('*%s:%s'):format(getParentName(lib, true), lib.name)
+ return lib, fullKey
+ else
+ local parentValue = source:get 'parent'
+ if parentValue and parentValue:getType() == 'string' then
+ -- *string.sub
+ local fullKey = ('*%s.%s'):format(getParentName(lib, false), lib.name)
+ return lib, fullKey
+ else
+ -- string.sub
+ local fullKey = ('%s.%s'):format(getParentName(lib, false), lib.name)
+ return lib, fullKey
+ end
+ end
+ else
+ local name = hoverName(source)
+ local libName = lib.nick or lib.name
+ if name == libName or not libName then
+ return lib, name
+ elseif name == '' then
+ return lib, libName
+ else
+ return lib, ('%s<%s>'):format(name, libName)
+ end
+ end
+end
+
+return function (source)
+ if source:bindValue() then
+ local lib, fullKey = findLib(source)
+ return lib, fullKey
+ end
+ if source:get 'in index' then
+ source = source:get 'in index'
+ local lib, fullKey = findLib(source)
+ return lib, fullKey
+ end
+ return nil
+end
diff --git a/script/core/find_source.lua b/script/core/find_source.lua
new file mode 100644
index 00000000..a64a047e
--- /dev/null
+++ b/script/core/find_source.lua
@@ -0,0 +1,59 @@
+local function isContainPos(obj, pos)
+ if obj.start <= pos and obj.finish >= pos then
+ return true
+ end
+ return false
+end
+
+local function isValidSource(source)
+ return source.start ~= nil and source.start ~= 0
+end
+
+local function matchFilter(source, filter)
+ if not filter then
+ return true
+ end
+ return filter[source.type]
+end
+
+local function findAtPos(vm, pos, filter)
+ local res = {}
+ vm:eachSource(function (source)
+ if isValidSource(source)
+ and isContainPos(source, pos)
+ and matchFilter(source, filter)
+ then
+ res[#res+1] = source
+ end
+ end)
+ if #res == 0 then
+ return nil
+ end
+ table.sort(res, function (a, b)
+ if a == b then
+ return false
+ end
+ local rangeA = a.finish - a.start
+ local rangeB = b.finish - b.start
+ -- 特殊处理:func 'str' 的情况下,list与string的范围会完全相同,此时取string
+ if rangeA == rangeB then
+ if b.type == 'call' and #b == 1 and b[1] == a then
+ return true
+ elseif a.type == 'call' and #a == 1 and a[1] == b then
+ return false
+ else
+ return a.id < b.id
+ end
+ end
+ return rangeA < rangeB
+ end)
+ local source = res[1]
+ if not source then
+ return nil
+ end
+ return source
+end
+
+return function (vm, pos, filter)
+ return findAtPos(vm, pos, filter)
+end
diff --git a/script/core/folding_range.lua b/script/core/folding_range.lua
new file mode 100644
index 00000000..e94d1ffe
--- /dev/null
+++ b/script/core/folding_range.lua
@@ -0,0 +1,73 @@
+local foldingType = {
+ ['function'] = {'region', 'end', },
+ ['localfunction'] = {'region', 'end', },
+ ['do'] = {'region', 'end', },
+ ['if'] = {'region', 'end', },
+ ['loop'] = {'region', 'end', },
+ ['in'] = {'region', 'end', },
+ ['while'] = {'region', 'end', },
+ ['repeat'] = {'region', 'until',},
+ ['table'] = {'region', '}', },
+ ['string'] = {'regtion', ']', },
+}
+
+return function (vm, comments)
+ local result = {}
+ vm:eachSource(function (source)
+ local tp = source.type
+ local data = foldingType[tp]
+ if not data then
+ return
+ end
+ local start = source.start
+ local finish = source.finish
+ if tp == 'repeat' then
+ if #source > 0 then
+ finish = source[#source].finish
+ else
+ finish = start + #'repeat'
+ end
+ finish = vm.text:find('until', finish, true) or finish
+ result[#result+1] = {
+ start = start,
+ finish = finish,
+ kind = data[1],
+ }
+ elseif tp == 'if' then
+ for i = 1, #source do
+ local block = source[i]
+ local nblock = source[i+1]
+ result[#result+1] = {
+ start = block.start,
+ finish = nblock and nblock.start or finish,
+ kind = data[1],
+ }
+ end
+ elseif tp == 'string' then
+ result[#result+1] = {
+ start = start,
+ finish = finish,
+ kind = data[1],
+ }
+ elseif data[1] == 'region' then
+ result[#result+1] = {
+ start = start,
+ finish = finish,
+ kind = data[1],
+ }
+ end
+ end)
+ if comments then
+ for _, comment in ipairs(comments) do
+ result[#result+1] = {
+ start = comment.start,
+ finish = comment.finish,
+ kind = 'comment',
+ }
+ end
+ end
+ if #result == 0 then
+ return nil
+ end
+ return result
+end
diff --git a/script/core/global.lua b/script/core/global.lua
new file mode 100644
index 00000000..961ad304
--- /dev/null
+++ b/script/core/global.lua
@@ -0,0 +1,49 @@
+local mt = {}
+mt.__index = mt
+
+function mt:markSet(uri)
+ if not uri then
+ return
+ end
+ self.set[uri] = true
+end
+
+function mt:markGet(uri)
+ if not uri then
+ return
+ end
+ self.get[uri] = true
+end
+
+function mt:clearGlobal(uri)
+ self.set[uri] = nil
+ self.get[uri] = nil
+end
+
+function mt:getAllUris()
+ local uris = {}
+ for uri in pairs(self.set) do
+ uris[#uris+1] = uri
+ end
+ for uri in pairs(self.get) do
+ if not self.set[uri] then
+ uris[#uris+1] = uri
+ end
+ end
+ return uris
+end
+
+function mt:hasSetGlobal(uri)
+ return self.set[uri] ~= nil
+end
+
+function mt:remove()
+end
+
+return function (lsp)
+ return setmetatable({
+ get = {},
+ set = {},
+ lsp = lsp,
+ }, mt)
+end
diff --git a/script/core/highlight.lua b/script/core/highlight.lua
new file mode 100644
index 00000000..2073573d
--- /dev/null
+++ b/script/core/highlight.lua
@@ -0,0 +1,54 @@
+local findSource = require 'core.find_source'
+local parser = require 'parser'
+
+local DocumentHighlightKind = {
+ Text = 1,
+ Read = 2,
+ Write = 3,
+}
+
+local function parseResult(source)
+ local positions = {}
+ if source:bindLabel() then
+ source:bindLabel():eachInfo(function (info, src)
+ positions[#positions+1] = { src.start, src.finish, DocumentHighlightKind.Text }
+ end)
+ return positions
+ end
+ if source:bindLocal() then
+ local loc = source:bindLocal()
+ local mark = {}
+ loc:eachInfo(function (info, src)
+ if not mark[src] then
+ mark[src] = info
+ positions[#positions+1] = { src.start, src.finish, DocumentHighlightKind.Text }
+ end
+ end)
+ return positions
+ end
+ if source:bindValue() and source:get 'parent' then
+ local parent = source:get 'parent'
+ local mark = {}
+ parent:eachInfo(function (info, src)
+ if not mark[src] and source.uri == src.uri then
+ mark[src] = info
+ if info.type == 'get child' or info.type == 'set child' then
+ if info[1] == source[1] then
+ positions[#positions+1] = {src.start, src.finish, DocumentHighlightKind.Text}
+ end
+ end
+ end
+ end)
+ return positions
+ end
+ return nil
+end
+
+return function (vm, pos)
+ local source = findSource(vm, pos)
+ if not source then
+ return nil
+ end
+ local positions = parseResult(source)
+ return positions
+end
diff --git a/script/core/hover/emmy_function.lua b/script/core/hover/emmy_function.lua
new file mode 100644
index 00000000..7c87954e
--- /dev/null
+++ b/script/core/hover/emmy_function.lua
@@ -0,0 +1,143 @@
+---@param emmy EmmyFunctionType
+local function buildEmmyArgs(emmy, object, select)
+ local start
+ if object then
+ start = 2
+ else
+ start = 1
+ end
+ local strs = {}
+ local args = {}
+ local i = 0
+ emmy:eachParam(function (name, typeObj)
+ i = i + 1
+ if i < start then
+ return
+ end
+ if i > start then
+ strs[#strs+1] = ', '
+ end
+ if i == select then
+ strs[#strs+1] = '@ARG'
+ end
+ strs[#strs+1] = name .. ': ' .. typeObj:getType()
+ args[#args+1] = strs[#strs]
+ if i == select then
+ strs[#strs+1] = '@ARG'
+ end
+ end)
+ local text = table.concat(strs)
+ local argLabel = {}
+ for i = 1, 2 do
+ local pos = text:find('@ARG', 1, true)
+ if pos then
+ if i == 1 then
+ argLabel[i] = pos
+ else
+ argLabel[i] = pos - 1
+ end
+ text = text:sub(1, pos-1) .. text:sub(pos+4)
+ end
+ end
+ if #argLabel == 0 then
+ argLabel = nil
+ end
+ return text, argLabel, args
+end
+
+local function buildEmmyReturns(emmy)
+ local rtns = {}
+ local i = 0
+ emmy:eachReturn(function (rtn)
+ i = i + 1
+ if i > 1 then
+ rtns[#rtns+1] = ('\n% 3d. '):format(i)
+ end
+ rtns[#rtns+1] = rtn:getType()
+ end)
+ if #rtns == 0 then
+ return '\n -> ' .. 'any'
+ else
+ return '\n -> ' .. table.concat(rtns)
+ end
+end
+
+local function buildEnum(lib)
+ if not lib.enums then
+ return ''
+ end
+ local container = table.container()
+ for _, enum in ipairs(lib.enums) do
+ if not enum.name or (not enum.enum and not enum.code) then
+ goto NEXT_ENUM
+ end
+ if not container[enum.name] then
+ container[enum.name] = {}
+ if lib.args then
+ for _, arg in ipairs(lib.args) do
+ if arg.name == enum.name then
+ container[enum.name].type = arg.type
+ break
+ end
+ end
+ end
+ if lib.returns then
+ for _, rtn in ipairs(lib.returns) do
+ if rtn.name == enum.name then
+ container[enum.name].type = rtn.type
+ break
+ end
+ end
+ end
+ end
+ table.insert(container[enum.name], enum)
+ ::NEXT_ENUM::
+ end
+ local strs = {}
+ local raw = {}
+ for name, enums in pairs(container) do
+ local tp
+ if type(enums.type) == 'table' then
+ tp = table.concat(enums.type, '/')
+ else
+ tp = enums.type
+ end
+ raw[name] = {}
+ strs[#strs+1] = ('\n%s: %s'):format(name, tp or 'any')
+ for _, enum in ipairs(enums) do
+ if enum.default then
+ strs[#strs+1] = '\n -> '
+ else
+ strs[#strs+1] = '\n | '
+ end
+ if enum.code then
+ strs[#strs+1] = tostring(enum.code)
+ else
+ strs[#strs+1] = ('%q'):format(enum.enum)
+ end
+ raw[name][#raw[name]+1] = strs[#strs]
+ if enum.description then
+ strs[#strs+1] = ' -- ' .. enum.description
+ end
+ end
+ end
+ return table.concat(strs), raw
+end
+
+return function (name, emmy, object, select)
+ local argStr, argLabel, args = buildEmmyArgs(emmy, object, select)
+ local returns = buildEmmyReturns(emmy)
+ local enum, rawEnum = buildEnum(emmy)
+ local tip = emmy.description
+ return {
+ label = ('function %s(%s)%s'):format(name, argStr, returns),
+ name = name,
+ argStr = argStr,
+ returns = returns,
+ description = tip,
+ enum = enum,
+ rawEnum = rawEnum,
+ argLabel = argLabel,
+ args = args,
+ }
+end
diff --git a/script/core/hover/function.lua b/script/core/hover/function.lua
new file mode 100644
index 00000000..3865f602
--- /dev/null
+++ b/script/core/hover/function.lua
@@ -0,0 +1,243 @@
+local emmyFunction = require 'core.hover.emmy_function'
+
+local function buildValueArgs(func, object, select)
+ if not func then
+ return '', nil
+ end
+ local names = {}
+ local values = {}
+ local options = {}
+ if func.argValues then
+ for i, value in ipairs(func.argValues) do
+ values[i] = value:getType()
+ end
+ end
+ if func.args then
+ for i, arg in ipairs(func.args) do
+ names[#names+1] = arg:getName()
+ local param = func:findEmmyParamByName(arg:getName())
+ if param then
+ values[i] = param:getType()
+ options[i] = param:getOption()
+ end
+ end
+ end
+ local strs = {}
+ local start = 1
+ if object then
+ start = 2
+ end
+ local max
+ if func:getSource() then
+ max = #names
+ else
+ max = math.max(#names, #values)
+ end
+ local args = {}
+ for i = start, max do
+ local name = names[i]
+ local value = values[i] or 'any'
+ local option = options[i]
+ if option and option.optional then
+ if i > start then
+ strs[#strs+1] = ' ['
+ else
+ strs[#strs+1] = '['
+ end
+ end
+ if i > start then
+ strs[#strs+1] = ', '
+ end
+
+ if i == select then
+ strs[#strs+1] = '@ARG'
+ end
+ if name then
+ strs[#strs+1] = name .. ': ' .. value
+ else
+ strs[#strs+1] = value
+ end
+ args[#args+1] = strs[#strs]
+ if i == select then
+ strs[#strs+1] = '@ARG'
+ end
+
+ if option and option.optional == 'self' then
+ strs[#strs+1] = ']'
+ end
+ end
+ if func:hasDots() then
+ if max > 0 then
+ strs[#strs+1] = ', '
+ end
+ strs[#strs+1] = '...'
+ end
+
+ if options then
+ for _, option in pairs(options) do
+ if option.optional == 'after' then
+ strs[#strs+1] = ']'
+ end
+ end
+ end
+
+ local text = table.concat(strs)
+ local argLabel = {}
+ for i = 1, 2 do
+ local pos = text:find('@ARG', 1, true)
+ if pos then
+ if i == 1 then
+ argLabel[i] = pos
+ else
+ argLabel[i] = pos - 1
+ end
+ text = text:sub(1, pos-1) .. text:sub(pos+4)
+ end
+ end
+ if #argLabel == 0 then
+ argLabel = nil
+ end
+ return text, argLabel, args
+end
+
+local function buildValueReturns(func)
+ if not func then
+ return '\n -> any'
+ end
+ if not func:get 'hasReturn' then
+ return ''
+ end
+ local strs = {}
+ local emmys = {}
+ local n = 0
+ func:eachEmmyReturn(function (emmy)
+ n = n + 1
+ emmys[n] = emmy
+ end)
+ if func.returns then
+ for i, rtn in ipairs(func.returns) do
+ local emmy = emmys[i]
+ local option = emmy and emmy.option
+ if option and option.optional then
+ if i > 1 then
+ strs[#strs+1] = ' ['
+ else
+ strs[#strs+1] = '['
+ end
+ end
+ if i > 1 then
+ strs[#strs+1] = ('\n% 3d. '):format(i)
+ end
+ if emmy and emmy.name then
+ strs[#strs+1] = ('%s: '):format(emmy.name)
+ elseif option and option.name then
+ strs[#strs+1] = ('%s: '):format(option.name)
+ end
+ strs[#strs+1] = rtn:getType()
+ if option and option.optional == 'self' then
+ strs[#strs+1] = ']'
+ end
+ end
+ for i = 1, #func.returns do
+ local emmy = emmys[i]
+ if emmy and emmy.option and emmy.option.optional == 'after' then
+ strs[#strs+1] = ']'
+ end
+ end
+ end
+ if #strs == 0 then
+ strs[1] = 'any'
+ end
+ return '\n -> ' .. table.concat(strs)
+end
+
+---@param func emmyFunction
+local function buildEnum(func)
+ if not func then
+ return nil
+ end
+ local params = func:getEmmyParams()
+ if not params then
+ return nil
+ end
+ local strs = {}
+ local raw = {}
+ for _, param in ipairs(params) do
+ local first = true
+ local name = param:getName()
+ raw[name] = {}
+ param:eachEnum(function (enum)
+ if first then
+ first = false
+ strs[#strs+1] = ('\n%s: %s'):format(param:getName(), param:getType())
+ end
+ if enum.default then
+ strs[#strs+1] = ('\n |>%s'):format(enum[1])
+ else
+ strs[#strs+1] = ('\n | %s'):format(enum[1])
+ end
+ if enum.comment then
+ strs[#strs+1] = ' -- ' .. enum.comment
+ end
+ raw[name][#raw[name]+1] = enum[1]
+ end)
+ end
+ if #strs == 0 then
+ return nil
+ end
+ return table.concat(strs), raw
+end
+
+local function getComment(func)
+ if not func then
+ return nil
+ end
+ local comments = {}
+ local params = func:getEmmyParams()
+ if params then
+ for _, param in ipairs(params) do
+ local option = param:getOption()
+ if option and option.comment then
+ comments[#comments+1] = ('+ `%s`*(%s)*: %s'):format(param:getName(), param:getType(), option.comment)
+ end
+ end
+ end
+ comments[#comments+1] = func:getComment()
+ if #comments == 0 then
+ return nil
+ end
+ return table.concat(comments, '\n\n')
+end
+
+local function getOverLoads(name, func, object, select)
+ local overloads = func and func:getEmmyOverLoads()
+ if not overloads then
+ return nil
+ end
+ local list = {}
+ for _, ol in ipairs(overloads) do
+ local hover = emmyFunction(name, ol, object, select)
+ list[#list+1] = hover.label
+ end
+ return table.concat(list, '\n')
+end
+
+return function (name, func, object, select)
+ local argStr, argLabel, args = buildValueArgs(func, object, select)
+ local returns = buildValueReturns(func)
+ local enum, rawEnum = buildEnum(func)
+ local comment = getComment(func)
+ local overloads = getOverLoads(name, func, object, select)
+ return {
+ label = ('function %s(%s)%s'):format(name, argStr, returns),
+ name = name,
+ argStr = argStr,
+ returns = returns,
+ description = comment,
+ enum = enum,
+ rawEnum = rawEnum,
+ argLabel = argLabel,
+ overloads = overloads,
+ args = args,
+ }
+end
diff --git a/script/core/hover/hover.lua b/script/core/hover/hover.lua
new file mode 100644
index 00000000..2ee5cf46
--- /dev/null
+++ b/script/core/hover/hover.lua
@@ -0,0 +1,326 @@
+local findLib = require 'core.find_lib'
+local getFunctionHover = require 'core.hover.function'
+local getFunctionHoverAsLib = require 'core.hover.lib_function'
+local getFunctionHoverAsEmmy = require 'core.hover.emmy_function'
+local buildValueName = require 'core.hover.name'
+
+local OriginTypes = {
+ ['any'] = true,
+ ['nil'] = true,
+ ['integer'] = true,
+ ['number'] = true,
+ ['boolean'] = true,
+ ['string'] = true,
+ ['thread'] = true,
+ ['userdata'] = true,
+ ['table'] = true,
+ ['function'] = true,
+}
+
+local function longString(str)
+ for i = 0, 10 do
+ local finish = ']' .. ('='):rep(i) .. ']'
+ if not str:find(finish, 1, true) then
+ return ('[%s[\n%s%s'):format(('='):rep(i), str, finish)
+ end
+ end
+ return ('%q'):format(str)
+end
+
+local function formatString(str)
+ if #str > 1000 then
+ str = str:sub(1000)
+ end
+ if str:find('[\r\n]') then
+ str = str:gsub('[\000-\008\011-\012\014-\031\127]', '')
+ return longString(str)
+ else
+ str = str:gsub('[\000-\008\011-\012\014-\031\127]', function (char)
+ return ('\\%03d'):format(char:byte())
+ end)
+ local single = str:find("'", 1, true)
+ local double = str:find('"', 1, true)
+ if single and double then
+ return longString(str)
+ elseif double then
+ return ("'%s'"):format(str)
+ else
+ return ('"%s"'):format(str)
+ end
+ end
+end
+
+local function formatLiteral(v)
+ if math.type(v) == 'float' then
+ return ('%.10f'):format(v):gsub('[0]*$', ''):gsub('%.$', '.0')
+ elseif type(v) == 'string' then
+ return formatString(v)
+ else
+ return ('%q'):format(v)
+ end
+end
+
+local function findClass(value)
+ -- 检查是否有emmy
+ local emmy = value:getEmmy()
+ if emmy then
+ return emmy:getType()
+ end
+ -- 检查对象元表
+ local metaValue = value:getMetaTable()
+ if not metaValue then
+ return nil
+ end
+ -- 检查元表中的 __name
+ local metaName = metaValue:rawGet('__name')
+ if metaName and type(metaName:getLiteral()) == 'string' then
+ return metaName:getLiteral()
+ end
+ -- 检查元表的 __index
+ local indexValue = metaValue:rawGet('__index')
+ if not indexValue then
+ return nil
+ end
+ -- 查找index方法中的以下字段: type name class
+ -- 允许多重继承
+ return indexValue:eachChild(function (k, v)
+ -- 键值类型必须均为字符串
+ if type(k) ~= 'string' then
+ return
+ end
+ if type(v:getLiteral()) ~= 'string' then
+ return
+ end
+ local lKey = k:lower()
+ if lKey == 'type'
+ or lKey == 'name'
+ or lKey == 'class'
+ then
+ -- 必须只有过一次赋值
+ local hasSet = false
+ local ok = v:eachInfo(function (info)
+ if info.type == 'set' then
+ if hasSet then
+ return false
+ else
+ hasSet = true
+ end
+ end
+ end)
+ if ok == false then
+ return false
+ end
+ return v:getLiteral()
+ end
+ end)
+end
+
+local function formatKey(key)
+ local kType = type(key)
+ if kType == 'table' then
+ key = ('[*%s]'):format(key:getType())
+ elseif math.type(key) == 'integer' then
+ key = ('[%03d]'):format(key)
+ elseif kType == 'string' then
+ if key:find '^%d' or key:find '[^%w_]' then
+ key = ('[%s]'):format(formatString(key))
+ end
+ elseif key == '' then
+ key = '[*any]'
+ else
+ key = ('[%s]'):format(key)
+ end
+ return key
+end
+
+local function unpackTable(value)
+ local lines = {}
+ value:eachChild(function (key, child)
+ key = formatKey(key)
+
+ local vType = type(child:getLiteral())
+ if vType == 'boolean'
+ or vType == 'integer'
+ or vType == 'number'
+ or vType == 'string'
+ then
+ lines[#lines+1] = ('%s: %s = %s'):format(key, child:getType(), formatLiteral(child:getLiteral()))
+ else
+ lines[#lines+1] = ('%s: %s'):format(key, child:getType())
+ end
+ end)
+ local emmy = value:getEmmy()
+ if emmy then
+ if emmy.type == 'emmy.arrayType' then
+ lines[#lines+1] = ('[*integer]: %s'):format(emmy:getName())
+ elseif emmy.type == 'emmy.tableType' then
+ lines[#lines+1] = ('[*%s]: %s'):format(emmy:getKeyType():getType(), emmy:getValueType():getType())
+ end
+ end
+ if #lines == 0 then
+ return '{}'
+ end
+
+ -- 整理一下表
+ local cleaned = {}
+ local used = {}
+ for _, line in ipairs(lines) do
+ if used[line] then
+ goto CONTINUE
+ end
+ used[line] = true
+ if line == '[*any]: any' then
+ goto CONTINUE
+ end
+ cleaned[#cleaned+1] = ' ' .. line .. ','
+ :: CONTINUE ::
+ end
+
+ table.sort(cleaned)
+ table.insert(cleaned, 1, '{')
+ cleaned[#cleaned+1] = '}'
+ return table.concat(cleaned, '\r\n')
+end
+
+local function getValueHover(source, name, value, lib)
+ local valueType = value:getType()
+ local class = findClass(value)
+
+ if class then
+ valueType = class
+ lib = nil
+ end
+
+ if not OriginTypes[valueType] then
+ valueType = '*' .. valueType
+ end
+
+ local tips = {}
+ local literal
+ if lib then
+ literal = lib.code or (lib.value and formatLiteral(lib.value))
+ tips[#tips+1] = lib.description
+ else
+ literal = value:getLiteral() and formatLiteral(value:getLiteral())
+ end
+
+ tips[#tips+1] = value:getComment()
+
+ local tp
+ if source:bindLocal() then
+ tp = 'local'
+ local loc = source:bindLocal()
+ if loc.tags then
+ local mark = {}
+ local tagBufs = {}
+ for _, tag in ipairs(loc.tags) do
+ local tagName = tag[1]
+ if not mark[tagName] then
+ mark[tagName] = true
+ tagBufs[#tagBufs+1] = ('<%s>'):format(tagName)
+ end
+ end
+ name = name .. ' ' .. table.concat(tagBufs, ' ')
+ end
+ tips[#tips+1] = loc:getComment()
+ elseif source:get 'global' then
+ tp = 'global'
+ elseif source:get 'simple' then
+ local simple = source:get 'simple'
+ if simple[1]:get 'global' then
+ tp = 'global'
+ else
+ tp = 'field'
+ end
+ else
+ tp = 'field'
+ end
+
+ local text
+ if valueType == 'table' then
+ text = ('%s %s: %s'):format(tp, name, unpackTable(value))
+ else
+ if literal == nil then
+ if class and not OriginTypes[class] then
+ text = ('%s %s: %s %s'):format(tp, name, valueType, unpackTable(value))
+ else
+ text = ('%s %s: %s'):format(tp, name, valueType)
+ end
+ else
+ text = ('%s %s: %s = %s'):format(tp, name, valueType, literal)
+ end
+ end
+
+ local tip
+ if #tips > 0 then
+ tip = table.concat(tips, '\n\n-------------\n\n')
+ end
+ return {
+ label = text,
+ description = tip,
+ }
+end
+
+local function hoverAsValue(source, lsp, select)
+ local lib, fullkey = findLib(source)
+ ---@type value
+ local value = source:findValue()
+ local name = fullkey or buildValueName(source)
+
+ local hover
+ if value:getType() == 'function' then
+ local object = source:get 'object'
+ if lib then
+ hover = getFunctionHoverAsLib(name, lib, object, select)
+ else
+ local emmy = value:getEmmy()
+ if emmy and emmy.type == 'emmy.functionType' then
+ hover = getFunctionHoverAsEmmy(name, emmy, object, select)
+ else
+ local func = value:getFunction()
+ hover = getFunctionHover(name, func, object, select)
+ end
+ end
+ else
+ hover = getValueHover(source, name, value, lib)
+ end
+
+ if not hover then
+ return nil
+ end
+ hover.name = name
+ return hover
+end
+
+local function hoverAsTargetUri(source, lsp)
+ local uri = source:get 'target uri'
+ if not lsp or not lsp.workspace then
+ return nil
+ end
+ local path = lsp.workspace:relativePathByUri(uri)
+ if not path then
+ return nil
+ end
+ return {
+ description = ('[%s](%s)'):format(path:string(), uri),
+ }
+end
+
+return function (source, lsp, select)
+ if not source then
+ return nil
+ end
+ if source:get 'target uri' then
+ return hoverAsTargetUri(source, lsp)
+ end
+ if source.type == 'name' and source:bindValue() then
+ return hoverAsValue(source, lsp, select)
+ end
+ if source.type == 'simple' then
+ source = source[#source]
+ if source.type == 'name' and source:bindValue() then
+ return hoverAsValue(source, lsp, select)
+ end
+ end
+ return nil
+end
diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua
new file mode 100644
index 00000000..be5b5632
--- /dev/null
+++ b/script/core/hover/init.lua
@@ -0,0 +1 @@
+return require 'core.hover.hover'
diff --git a/script/core/hover/lib_function.lua b/script/core/hover/lib_function.lua
new file mode 100644
index 00000000..06087312
--- /dev/null
+++ b/script/core/hover/lib_function.lua
@@ -0,0 +1,222 @@
+local lang = require 'language'
+local config = require 'config'
+local function buildLibArgs(lib, object, select)
+ if not lib.args then
+ return ''
+ end
+ local start
+ if object then
+ start = 2
+ else
+ start = 1
+ end
+ local strs = {}
+ local args = {}
+ for i = start, #lib.args do
+ local arg = lib.args[i]
+ if arg.optional then
+ if i > start then
+ strs[#strs+1] = ' ['
+ else
+ strs[#strs+1] = '['
+ end
+ end
+ if i > start then
+ strs[#strs+1] = ', '
+ end
+
+ local argStr = {}
+ if i == select then
+ argStr[#argStr+1] = '@ARG'
+ end
+ local name = ''
+ if arg.name then
+ name = ('%s: '):format(arg.name)
+ end
+ if type(arg.type) == 'table' then
+ name = name .. table.concat(arg.type, '/')
+ else
+ name = name .. (arg.type or 'any')
+ end
+ argStr[#argStr+1] = name
+ args[#args+1] = name
+ if arg.default then
+ argStr[#argStr+1] = ('(%q)'):format(arg.default)
+ end
+ if i == select then
+ argStr[#argStr+1] = '@ARG'
+ end
+
+ for _, str in ipairs(argStr) do
+ strs[#strs+1] = str
+ end
+ if arg.optional == 'self' then
+ strs[#strs+1] = ']'
+ end
+ end
+ for _, arg in ipairs(lib.args) do
+ if arg.optional == 'after' then
+ strs[#strs+1] = ']'
+ end
+ end
+ local text = table.concat(strs)
+ local argLabel = {}
+ for i = 1, 2 do
+ local pos = text:find('@ARG', 1, true)
+ if pos then
+ if i == 1 then
+ argLabel[i] = pos
+ else
+ argLabel[i] = pos - 1
+ end
+ text = text:sub(1, pos-1) .. text:sub(pos+4)
+ end
+ end
+ if #argLabel == 0 then
+ argLabel = nil
+ end
+ return text, argLabel, args
+end
+
+local function buildLibReturns(lib)
+ if not lib.returns then
+ return ''
+ end
+ local strs = {}
+ for i, rtn in ipairs(lib.returns) do
+ if rtn.optional then
+ if i > 1 then
+ strs[#strs+1] = ' ['
+ else
+ strs[#strs+1] = '['
+ end
+ end
+ if i > 1 then
+ strs[#strs+1] = ('\n% 3d. '):format(i)
+ end
+ if rtn.name then
+ strs[#strs+1] = ('%s: '):format(rtn.name)
+ end
+ if type(rtn.type) == 'table' then
+ strs[#strs+1] = table.concat(rtn.type, '/')
+ else
+ strs[#strs+1] = rtn.type or 'any'
+ end
+ if rtn.default then
+ strs[#strs+1] = ('(%q)'):format(rtn.default)
+ end
+ if rtn.optional == 'self' then
+ strs[#strs+1] = ']'
+ end
+ end
+ for _, rtn in ipairs(lib.returns) do
+ if rtn.optional == 'after' then
+ strs[#strs+1] = ']'
+ end
+ end
+ return '\n -> ' .. table.concat(strs)
+end
+
+local function buildEnum(lib)
+ if not lib.enums then
+ return ''
+ end
+ local container = table.container()
+ for _, enum in ipairs(lib.enums) do
+ if not enum.name or (not enum.enum and not enum.code) then
+ goto NEXT_ENUM
+ end
+ if not container[enum.name] then
+ container[enum.name] = {}
+ if lib.args then
+ for _, arg in ipairs(lib.args) do
+ if arg.name == enum.name then
+ container[enum.name].type = arg.type
+ break
+ end
+ end
+ end
+ if lib.returns then
+ for _, rtn in ipairs(lib.returns) do
+ if rtn.name == enum.name then
+ container[enum.name].type = rtn.type
+ break
+ end
+ end
+ end
+ end
+ table.insert(container[enum.name], enum)
+ ::NEXT_ENUM::
+ end
+ local strs = {}
+ local raw = {}
+ for name, enums in pairs(container) do
+ local tp
+ if type(enums.type) == 'table' then
+ tp = table.concat(enums.type, '/')
+ else
+ tp = enums.type
+ end
+ strs[#strs+1] = ('\n%s: %s'):format(name, tp or 'any')
+ raw[name] = {}
+ for _, enum in ipairs(enums) do
+ if enum.default then
+ strs[#strs+1] = '\n -> '
+ else
+ strs[#strs+1] = '\n | '
+ end
+ if enum.code then
+ strs[#strs+1] = tostring(enum.code)
+ else
+ strs[#strs+1] = tostring(enum.enum)
+ end
+ raw[name][#raw[name]+1] = strs[#strs]
+ if enum.description then
+ strs[#strs+1] = ' -- ' .. enum.description
+ end
+ end
+ end
+ return table.concat(strs), raw
+end
+
+local function buildDoc(lib)
+ local doc = lib.doc
+ if not doc then
+ return
+ end
+ if lib.web then
+ return lang.script(lib.web, doc)
+ end
+ local version = config.config.runtime.version
+ if version == 'Lua 5.1' then
+ return lang.script('HOVER_DOCUMENT_LUA51', doc)
+ elseif version == 'Lua 5.2' then
+ return lang.script('HOVER_DOCUMENT_LUA52', doc)
+ elseif version == 'Lua 5.3' then
+ return lang.script('HOVER_DOCUMENT_LUA53', doc)
+ elseif version == 'Lua 5.4' then
+ return lang.script('HOVER_DOCUMENT_LUA54', doc)
+ elseif version == 'LuaJIT' then
+ return lang.script('HOVER_DOCUMENT_LUAJIT', doc)
+ end
+end
+
+return function (name, lib, object, select)
+ local argStr, argLabel, args = buildLibArgs(lib, object, select)
+ local returns = buildLibReturns(lib)
+ local enum, rawEnum = buildEnum(lib)
+ local tip = lib.description
+ local doc = buildDoc(lib)
+ return {
+ label = ('function %s(%s)%s'):format(name, argStr, returns),
+ name = name,
+ argStr = argStr,
+ returns = returns,
+ description = tip,
+ enum = enum,
+ rawEnum = rawEnum,
+ argLabel = argLabel,
+ doc = doc,
+ args = args,
+ }
+end
diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua
new file mode 100644
index 00000000..763083b9
--- /dev/null
+++ b/script/core/hover/name.lua
@@ -0,0 +1,38 @@
+local getName = require 'core.name'
+
+return function (source)
+ if not source then
+ return ''
+ end
+ local value = source:bindValue()
+ if not value then
+ return ''
+ end
+ local func = value:getFunction()
+ local declarat
+ if func and func:getSource() then
+ declarat = func:getSource().name
+ else
+ declarat = source
+ end
+ if not declarat then
+ -- 如果声明者没有给名字,则找一个合适的名字
+ local names = {}
+ value:eachInfo(function (info, src)
+ if info.type == 'local' or info.type == 'set' or info.type == 'return' then
+ if src.type == 'name' and src.uri == value.uri then
+ names[#names+1] = src
+ end
+ end
+ end)
+ if #names == 0 then
+ return ''
+ end
+ table.sort(names, function (a, b)
+ return a.id < b.id
+ end)
+ return names[1][1] or ''
+ end
+
+ return getName(declarat, source)
+end
diff --git a/script/core/implementation.lua b/script/core/implementation.lua
new file mode 100644
index 00000000..f51a97ca
--- /dev/null
+++ b/script/core/implementation.lua
@@ -0,0 +1,204 @@
+local function parseValueSimily(vm, source, lsp)
+ local key = source[1]
+ if not key then
+ return nil
+ end
+ local positions = {}
+ vm:eachSource(function (other)
+ if other == source then
+ return
+ end
+ if other[1] == key
+ and not other:bindLocal()
+ and other:bindValue()
+ and other:action() == 'set'
+ and source:bindValue() ~= other:bindValue()
+ then
+ positions[#positions+1] = {
+ other.start,
+ other.finish,
+ }
+ end
+ end)
+ if #positions == 0 then
+ return nil
+ end
+ return positions
+end
+
+local function parseValueCrossFile(vm, source, lsp)
+ local value = source:bindValue()
+ local positions = {}
+ value:eachInfo(function (info, src)
+ if info.type == 'local' and src.uri == value.uri then
+ positions[#positions+1] = {
+ src.start,
+ src.finish,
+ value.uri,
+ }
+ return true
+ end
+ end)
+ if #positions > 0 then
+ return positions
+ end
+
+ value:eachInfo(function (info, src)
+ if info.type == 'set' and src.uri == value.uri then
+ positions[#positions+1] = {
+ src.start,
+ src.finish,
+ value.uri,
+ }
+ end
+ end)
+ if #positions > 0 then
+ return positions
+ end
+
+ value:eachInfo(function (info, src)
+ if info.type == 'return' and src.uri == value.uri then
+ positions[#positions+1] = {
+ src.start,
+ src.finish,
+ value.uri,
+ }
+ end
+ end)
+ if #positions > 0 then
+ return positions
+ end
+
+ local destVM = lsp:getVM(value.uri)
+ if not destVM then
+ positions[#positions+1] = {
+ 0, 0, value.uri,
+ }
+ return positions
+ end
+
+ local result = parseValueSimily(destVM, source, lsp)
+ if result then
+ for _, position in ipairs(result) do
+ positions[#positions+1] = position
+ position[3] = value.uri
+ end
+ end
+ if #positions > 0 then
+ return positions
+ end
+
+ return positions
+end
+
+local function parseValue(vm, source, lsp)
+ local positions = {}
+ local mark = {}
+
+ local function callback(src)
+ if source == src then
+ return
+ end
+ if mark[src] then
+ return
+ end
+ mark[src] = true
+ if src.start == 0 then
+ return
+ end
+ local uri = src.uri
+ if uri == '' then
+ uri = nil
+ end
+ positions[#positions+1] = {
+ src.start,
+ src.finish,
+ uri,
+ }
+ end
+
+ if source:bindValue() then
+ source:bindValue():eachInfo(function (info, src)
+ if info.type == 'set' or info.type == 'local' or info.type == 'return' then
+ callback(src)
+ return true
+ end
+ end)
+ end
+ local parent = source:get 'parent'
+ if parent then
+ parent:eachInfo(function (info, src)
+ if info[1] == source[1] then
+ if info.type == 'set child' then
+ callback(src)
+ end
+ end
+ end)
+ end
+ if #positions == 0 then
+ return nil
+ end
+ return positions
+end
+
+local function parseLabel(vm, label, lsp)
+ local positions = {}
+ label:eachInfo(function (info, src)
+ if info.type == 'set' then
+ positions[#positions+1] = {
+ src.start,
+ src.finish,
+ }
+ end
+ end)
+ if #positions == 0 then
+ return nil
+ end
+ return positions
+end
+
+local function jumpUri(vm, source, lsp)
+ local uri = source:get 'target uri'
+ local positions = {}
+ positions[#positions+1] = {
+ 0, 0, uri,
+ }
+ return positions
+end
+
+local function parseClass(vm, source)
+ local className = source:get 'emmy class'
+ local positions = {}
+ vm.emmyMgr:eachClass(className, function (class)
+ local src = class:getSource()
+ positions[#positions+1] = {
+ src.start,
+ src.finish,
+ src.uri,
+ }
+ end)
+ return positions
+end
+
+return function (vm, source, lsp)
+ if not source then
+ return nil
+ end
+ if source:bindValue() then
+ return parseValue(vm, source, lsp)
+ or parseValueSimily(vm, source, lsp)
+ end
+ if source:bindLabel() then
+ return parseLabel(vm, source:bindLabel(), lsp)
+ end
+ if source:get 'target uri' then
+ return jumpUri(vm, source, lsp)
+ end
+ if source:get 'in index' then
+ return parseValue(vm, source, lsp)
+ or parseValueSimily(vm, source, lsp)
+ end
+ if source:get 'emmy class' then
+ return parseClass(vm, source)
+ end
+end
diff --git a/script/core/init.lua b/script/core/init.lua
new file mode 100644
index 00000000..213dbaca
--- /dev/null
+++ b/script/core/init.lua
@@ -0,0 +1,19 @@
+local api = {
+ definition = require 'core.definition',
+ implementation = require 'core.implementation',
+ references = require 'core.references',
+ rename = require 'core.rename',
+ hover = require 'core.hover',
+ diagnostics = require 'core.diagnostics',
+ findSource = require 'core.find_source',
+ findLib = require 'core.find_lib',
+ completion = require 'core.completion',
+ signature = require 'core.signature',
+ documentSymbol = require 'core.document_symbol',
+ global = require 'core.global',
+ highlight = require 'core.highlight',
+ codeAction = require 'core.code_action',
+ foldingRange = require 'core.folding_range',
+}
+
+return api
diff --git a/script/core/library.lua b/script/core/library.lua
new file mode 100644
index 00000000..d5edad66
--- /dev/null
+++ b/script/core/library.lua
@@ -0,0 +1,296 @@
+local lni = require 'lni'
+local fs = require 'bee.filesystem'
+local config = require 'config'
+
+local Library = {}
+
+local function mergeEnum(lib, locale)
+ if not lib or not locale then
+ return
+ end
+ local pack = {}
+ for _, enum in ipairs(lib) do
+ if enum.enum then
+ pack[enum.enum] = enum
+ end
+ if enum.code then
+ pack[enum.code] = enum
+ end
+ end
+ for _, enum in ipairs(locale) do
+ if pack[enum.enum] then
+ if enum.description then
+ pack[enum.enum].description = enum.description
+ end
+ end
+ if pack[enum.code] then
+ if enum.description then
+ pack[enum.code].description = enum.description
+ end
+ end
+ end
+end
+
+local function mergeField(lib, locale)
+ if not lib or not locale then
+ return
+ end
+ local pack = {}
+ for _, field in ipairs(lib) do
+ if field.field then
+ pack[field.field] = field
+ end
+ end
+ for _, field in ipairs(locale) do
+ if pack[field.field] then
+ if field.description then
+ pack[field.field].description = field.description
+ end
+ end
+ end
+end
+
+local function mergeLocale(libs, locale)
+ if not libs or not locale then
+ return
+ end
+ for name in pairs(locale) do
+ if libs[name] then
+ if locale[name].description then
+ libs[name].description = locale[name].description
+ end
+ mergeEnum(libs[name].enums, locale[name].enums)
+ mergeField(libs[name].fields, locale[name].fields)
+ end
+ end
+end
+
+local function isMatchVersion(version)
+ if not version then
+ return true
+ end
+ local runtimeVersion = config.config.runtime.version
+ if type(version) == 'table' then
+ for i = 1, #version do
+ if version[i] == runtimeVersion then
+ return true
+ end
+ end
+ else
+ if version == runtimeVersion then
+ return true
+ end
+ end
+ return false
+end
+
+local function insertGlobal(tbl, key, value)
+ if not isMatchVersion(value.version) then
+ return false
+ end
+ if not value.doc then
+ value.doc = key
+ end
+ tbl[key] = value
+ return true
+end
+
+local function insertOther(tbl, key, value)
+ if not value.version then
+ return
+ end
+ if not tbl[key] then
+ tbl[key] = {}
+ end
+ if type(value.version) == 'string' then
+ tbl[key][#tbl[key]+1] = value.version
+ elseif type(value.version) == 'table' then
+ for _, version in ipairs(value.version) do
+ if type(version) == 'string' then
+ tbl[key][#tbl[key]+1] = version
+ end
+ end
+ end
+ table.sort(tbl[key])
+end
+
+local function insertCustom(tbl, key, value, libName)
+ if not tbl[key] then
+ tbl[key] = {}
+ end
+ tbl[key][#tbl[key]+1] = libName
+ table.sort(tbl[key])
+end
+
+local function isEnableGlobal(libName)
+ if config.config.runtime.library[libName] then
+ return true
+ end
+ if libName:sub(1, 1) == '@' then
+ return true
+ end
+ return false
+end
+
+local function mergeSource(alllibs, name, lib, libName)
+ if not lib.source then
+ if isEnableGlobal(libName) then
+ local suc = insertGlobal(alllibs.global, name, lib)
+ if not suc then
+ insertOther(alllibs.other, name, lib)
+ end
+ else
+ insertCustom(alllibs.custom, name, lib, libName)
+ end
+ return
+ end
+ for _, source in ipairs(lib.source) do
+ local sourceName = source.name or name
+ if source.type == 'global' then
+ if isEnableGlobal(libName) then
+ local suc = insertGlobal(alllibs.global, sourceName, lib)
+ if not suc then
+ insertOther(alllibs.other, sourceName, lib)
+ end
+ else
+ insertCustom(alllibs.custom, sourceName, lib, libName)
+ end
+ elseif source.type == 'library' then
+ insertGlobal(alllibs.library, sourceName, lib)
+ elseif source.type == 'object' then
+ insertGlobal(alllibs.object, sourceName, lib)
+ end
+ end
+end
+
+local function copy(t)
+ local new = {}
+ for k, v in pairs(t) do
+ new[k] = v
+ end
+ return new
+end
+
+local function insertChild(tbl, name, key, value)
+ if not name or not key then
+ return
+ end
+ if not isMatchVersion(value.version) then
+ return
+ end
+ if not value.doc then
+ value.doc = ('%s.%s'):format(name, key)
+ end
+ if not tbl[name] then
+ tbl[name] = {
+ type = name,
+ name = name,
+ child = {},
+ }
+ end
+ tbl[name].child[key] = copy(value)
+end
+
+local function mergeParent(alllibs, name, lib, libName)
+ for _, parent in ipairs(lib.parent) do
+ if parent.type == 'global' then
+ if isEnableGlobal(libName) then
+ insertChild(alllibs.global, parent.name, name, lib)
+ end
+ elseif parent.type == 'library' then
+ insertChild(alllibs.library, parent.name, name, lib)
+ elseif parent.type == 'object' then
+ insertChild(alllibs.object, parent.name, name, lib)
+ end
+ end
+end
+
+local function mergeLibs(alllibs, libs, libName)
+ if not libs then
+ return
+ end
+ for _, lib in pairs(libs) do
+ if lib.parent then
+ mergeParent(alllibs, lib.name, lib, libName)
+ else
+ mergeSource(alllibs, lib.name, lib, libName)
+ end
+ end
+end
+
+local function loadLocale(language, relative)
+ local localePath = ROOT / 'locale' / language / relative
+ local localeBuf = io.load(localePath)
+ if localeBuf then
+ local locale = table.container()
+ xpcall(lni, log.error, localeBuf, localePath:string(), {locale})
+ return locale
+ end
+ return nil
+end
+
+local function fix(libs)
+ for name, lib in pairs(libs) do
+ lib.name = lib.name or name
+ lib.child = {}
+ end
+end
+
+local function scan(path)
+ local result = {path}
+ local i = 0
+ return function ()
+ i = i + 1
+ local current = result[i]
+ if not current then
+ return nil
+ end
+ if fs.is_directory(current) then
+ for path in current:list_directory() do
+ result[#result+1] = path
+ end
+ end
+ return current
+ end
+end
+
+local function init()
+ local lang = require 'language'
+ local id = lang.id
+ Library.global = table.container()
+ Library.library = table.container()
+ Library.object = table.container()
+ Library.other = table.container()
+ Library.custom = table.container()
+
+ for libPath in (ROOT / 'libs'):list_directory() do
+ local enableGlobal
+ local libName = libPath:filename():string()
+ for path in scan(libPath) do
+ local libs
+ local buf = io.load(path)
+ if buf then
+ libs = table.container()
+ xpcall(lni, log.error, buf, path:string(), {libs})
+ fix(libs)
+ end
+ local relative = fs.relative(path, ROOT)
+
+ local locale = loadLocale('en-US', relative)
+ mergeLocale(libs, locale)
+ if id ~= 'en-US' then
+ locale = loadLocale(id, relative)
+ mergeLocale(libs, locale)
+ end
+ mergeLibs(Library, libs, libName)
+ end
+ end
+end
+
+function Library.reload()
+ init()
+end
+
+init()
+
+return Library
diff --git a/script/core/matchKey.lua b/script/core/matchKey.lua
new file mode 100644
index 00000000..b46250cb
--- /dev/null
+++ b/script/core/matchKey.lua
@@ -0,0 +1,30 @@
+return function (me, other)
+ if me == other then
+ return true
+ end
+ if me == '' then
+ return true
+ end
+ if #me > #other then
+ return false
+ end
+ local lMe = me:lower()
+ local lOther = other:lower()
+ if lMe == lOther:sub(1, #lMe) then
+ return true
+ end
+ local chars = {}
+ for i = 1, #lOther do
+ local c = lOther:sub(i, i)
+ chars[c] = (chars[c] or 0) + 1
+ end
+ for i = 1, #lMe do
+ local c = lMe:sub(i, i)
+ if chars[c] and chars[c] > 0 then
+ chars[c] = chars[c] - 1
+ else
+ return false
+ end
+ end
+ return true
+end
diff --git a/script/core/name.lua b/script/core/name.lua
new file mode 100644
index 00000000..54947974
--- /dev/null
+++ b/script/core/name.lua
@@ -0,0 +1,70 @@
+return function (source, caller)
+ if not source then
+ return ''
+ end
+ local key
+ if source:get 'simple' then
+ local simple = source:get 'simple'
+ local chars = {}
+ for i, obj in ipairs(simple) do
+ if obj.type == 'name' then
+ chars[i] = obj[1]
+ elseif obj.type == 'index' then
+ chars[i] = '[?]'
+ elseif obj.type == 'call' then
+ chars[i] = '(?)'
+ elseif obj.type == ':' then
+ chars[i] = ':'
+ elseif obj.type == '.' then
+ chars[i] = '.'
+ else
+ chars[i] = '*' .. obj.type
+ end
+ if obj == source then
+ break
+ end
+ end
+ key = table.concat(chars)
+ elseif source.type == 'name' then
+ key = source[1]
+ elseif source.type == 'string' then
+ key = ('%q'):format(source[1])
+ elseif source.type == 'number' or source.type == 'boolean' then
+ key = tostring(source[1])
+ elseif source.type == 'simple' then
+ local chars = {}
+ for i, obj in ipairs(source) do
+ if obj.type == 'name' then
+ chars[i] = obj[1]
+ elseif obj.type == 'index' then
+ chars[i] = '[?]'
+ elseif obj.type == 'call' then
+ chars[i] = '(?)'
+ elseif obj.type == ':' then
+ chars[i] = ':'
+ elseif obj.type == '.' then
+ chars[i] = '.'
+ else
+ chars[i] = '*' .. obj.type
+ end
+ end
+ -- 这里有个特殊处理
+ -- function mt:func() 以 mt.func 的形式调用时
+ -- hover 显示为 mt.func(self)
+ if caller then
+ if chars[#chars-1] == ':' then
+ if not caller:get 'object' then
+ chars[#chars-1] = '.'
+ end
+ elseif chars[#chars-1] == '.' then
+ if caller:get 'object' then
+ chars[#chars-1] = ':'
+ end
+ end
+ end
+ key = table.concat(chars)
+ else
+ key = ''
+ end
+ return key
+end
diff --git a/script/core/references.lua b/script/core/references.lua
new file mode 100644
index 00000000..33b38fec
--- /dev/null
+++ b/script/core/references.lua
@@ -0,0 +1,91 @@
+local findSource = require 'core.find_source'
+
+local function parseResult(vm, source, declarat, callback)
+ local isGlobal
+ if source:bindLabel() then
+ source:bindLabel():eachInfo(function (info, src)
+ if (declarat and info.type == 'set') or info.type == 'get' then
+ callback(src)
+ end
+ end)
+ end
+ if source:bindLocal() then
+ local loc = source:bindLocal()
+ callback(loc:getSource())
+ loc:eachInfo(function (info, src)
+ if (declarat and info.type == 'set') or info.type == 'get' then
+ callback(src)
+ end
+ end)
+ loc:getValue():eachInfo(function (info, src)
+ if (declarat and (info.type == 'set' or info.type == 'local' or info.type == 'return')) or info.type == 'get' then
+ callback(src)
+ end
+ end)
+ end
+ if source:bindFunction() then
+ if declarat then
+ callback(source:bindFunction():getSource())
+ end
+ source:bindFunction():eachInfo(function (info, src)
+ if (declarat and (info.type == 'set' or info.type == 'local')) or info.type == 'get' then
+ callback(src)
+ end
+ end)
+ end
+ if source:bindValue() then
+ source:bindValue():eachInfo(function (info, src)
+ if (declarat and (info.type == 'set' or info.type == 'local')) or info.type == 'get' then
+ callback(src)
+ end
+ end)
+ if source:bindValue():isGlobal() then
+ isGlobal = true
+ end
+ end
+ local parent = source:get 'parent'
+ if parent then
+ parent:eachInfo(function (info, src)
+ if info[1] == source[1] then
+ if (declarat and info.type == 'set child') or info.type == 'get child' then
+ callback(src)
+ end
+ end
+ end)
+ end
+ --local emmy = source:getEmmy()
+ --if emmy then
+ -- if emmy.type == 'emmy.class' or emmy.type == 'emmy.type' --then
+--
+ -- end
+ --end
+ return isGlobal
+end
+
+return function (vm, pos, declarat)
+ local source = findSource(vm, pos)
+ if not source then
+ return nil
+ end
+ local positions = {}
+ local mark = {}
+ local isGlobal = parseResult(vm, source, declarat, function (src)
+ if mark[src] then
+ return
+ end
+ mark[src] = true
+ if src.start == 0 then
+ return
+ end
+ local uri = src.uri
+ if uri == '' then
+ uri = nil
+ end
+ positions[#positions+1] = {
+ src.start,
+ src.finish,
+ uri,
+ }
+ end)
+ return positions, isGlobal
+end
diff --git a/script/core/rename.lua b/script/core/rename.lua
new file mode 100644
index 00000000..3a2e8532
--- /dev/null
+++ b/script/core/rename.lua
@@ -0,0 +1,72 @@
+local findSource = require 'core.find_source'
+local parser = require 'parser'
+
+local function parseResult(source, newName)
+ local positions = {}
+ if source:bindLabel() then
+ if not parser:grammar(newName, 'Name') then
+ return nil
+ end
+ source:bindLabel():eachInfo(function (info, src)
+ positions[#positions+1] = { src.start, src.finish, src:getUri() }
+ end)
+ return positions
+ end
+ if source:bindLocal() then
+ local loc = source:bindLocal()
+ if loc:get 'hide' then
+ return nil
+ end
+ if source:get 'in index' then
+ if not parser:grammar(newName, 'Exp') then
+ return positions
+ end
+ else
+ if not parser:grammar(newName, 'Name') then
+ return positions
+ end
+ end
+ local mark = {}
+ loc:eachInfo(function (info, src)
+ if not mark[src] then
+ mark[src] = info
+ positions[#positions+1] = { src.start, src.finish, src:getUri() }
+ end
+ end)
+ return positions
+ end
+ if source:bindValue() and source:get 'parent' then
+ if source:get 'in index' then
+ if not parser:grammar(newName, 'Exp') then
+ return positions
+ end
+ else
+ if not parser:grammar(newName, 'Name') then
+ return positions
+ end
+ end
+ local parent = source:get 'parent'
+ local mark = {}
+ parent:eachInfo(function (info, src)
+ if not mark[src] then
+ mark[src] = info
+ if info.type == 'get child' or info.type == 'set child' then
+ if info[1] == source[1] then
+ positions[#positions+1] = {src.start, src.finish, src:getUri()}
+ end
+ end
+ end
+ end)
+ return positions
+ end
+ return nil
+end
+
+return function (vm, pos, newName)
+ local source = findSource(vm, pos)
+ if not source then
+ return nil
+ end
+ local positions = parseResult(source, newName)
+ return positions
+end
diff --git a/script/core/signature.lua b/script/core/signature.lua
new file mode 100644
index 00000000..bbe35ffa
--- /dev/null
+++ b/script/core/signature.lua
@@ -0,0 +1,133 @@
+local getFunctionHover = require 'core.hover.function'
+local getFunctionHoverAsLib = require 'core.hover.lib_function'
+local getFunctionHoverAsEmmy = require 'core.hover.emmy_function'
+local findLib = require 'core.find_lib'
+local buildValueName = require 'core.hover.name'
+local findSource = require 'core.find_source'
+
+local function findCall(vm, pos)
+ local results = {}
+ vm:eachSource(function (src)
+ if src.type == 'call'
+ and src.start <= pos
+ and src.finish >= pos
+ then
+ results[#results+1] = src
+ end
+ end)
+ if #results == 0 then
+ return nil
+ end
+ -- 可能处于 'func1(func2(' 的嵌套中,将最近的call放到最前面
+ table.sort(results, function (a, b)
+ return a.start > b.start
+ end)
+ return results
+end
+
+local function getSelect(args, pos)
+ if not args then
+ return 1
+ end
+ for i, arg in ipairs(args) do
+ if arg.start <= pos and arg.finish >= pos - 1 then
+ return i
+ end
+ end
+ return #args + 1
+end
+
+local function getFunctionSource(call)
+ local simple = call:get 'simple'
+ for i, source in ipairs(simple) do
+ if source == call then
+ return simple[i-1]
+ end
+ end
+ return nil
+end
+
+local function getHover(call, pos)
+ local args = call:bindCall()
+ if not args then
+ return nil
+ end
+
+ local value = call:findCallFunction()
+ if not value then
+ return nil
+ end
+
+ local select = getSelect(args, pos)
+ local source = getFunctionSource(call)
+ local object = source:get 'object'
+ local lib, fullkey = findLib(source)
+ local name = fullkey or buildValueName(source)
+ local hover
+ if lib then
+ hover = getFunctionHoverAsLib(name, lib, object, select)
+ else
+ local emmy = value:getEmmy()
+ if emmy and emmy.type == 'emmy.functionType' then
+ hover = getFunctionHoverAsEmmy(name, emmy, object, select)
+ else
+ ---@type emmyFunction
+ local func = value:getFunction()
+ hover = getFunctionHover(name, func, object, select)
+ local overLoads = func and func:getEmmyOverLoads()
+ if overLoads then
+ for _, ol in ipairs(overLoads) do
+ hover = getFunctionHoverAsEmmy(name, ol, object, select)
+ end
+ end
+ end
+ end
+ return hover
+end
+
+local function isInFunctionOrTable(call, pos)
+ local args = call:bindCall()
+ if not args then
+ return false
+ end
+ local select = getSelect(args, pos)
+ local arg = args[select]
+ if not arg then
+ return false
+ end
+ if arg.type == 'function' or arg.type == 'table' then
+ return true
+ end
+ return false
+end
+
+return function (vm, pos)
+ local source = findSource(vm, pos) or findSource(vm, pos-1)
+ if not source or source.type == 'string' then
+ return
+ end
+ local calls = findCall(vm, pos)
+ if not calls or #calls == 0 then
+ return nil
+ end
+
+ local nearCall = calls[1]
+ if isInFunctionOrTable(nearCall, pos) then
+ return nil
+ end
+
+ local hover = getHover(nearCall, pos)
+ if not hover then
+ return nil
+ end
+
+ -- skip `name(`
+ local head = #hover.name + 1
+ hover.label = ('%s(%s)'):format(hover.name, hover.argStr)
+ if hover.argLabel then
+ hover.argLabel[1] = hover.argLabel[1] + head
+ hover.argLabel[2] = hover.argLabel[2] + head
+ end
+
+ return { hover }
+end
diff --git a/script/core/snippet.lua b/script/core/snippet.lua
new file mode 100644
index 00000000..7532ce9b
--- /dev/null
+++ b/script/core/snippet.lua
@@ -0,0 +1,64 @@
+local snippet = {}
+
+local function add(cate, key, label)
+ return function (text)
+ if not snippet[cate] then
+ snippet[cate] = {}
+ end
+ if not snippet[cate][key] then
+ snippet[cate][key] = {}
+ end
+ snippet[cate][key][#snippet[cate][key]+1] = {
+ label = label,
+ text = text,
+ }
+ end
+end
+
+add('key', 'do', 'do .. end') [[
+do
+ $0
+end]]
+
+add('key', 'elseif', 'elseif .. then')
+[[elseif ${1:true} then]]
+
+add('key', 'for', 'for .. in') [[
+for ${1:key, value} in ${2:pairs(t)} do
+ $0
+end]]
+
+add('key', 'for', 'for i = ..') [[
+for ${1:i} = ${2:1}, ${3:10, 2} do
+ $0
+end]]
+
+add('key', 'function', 'function ()') [[
+function $1(${2:arg1, arg2, arg3})
+ $0
+end]]
+
+add('key', 'local', 'local function') [[
+local function ${1:name}(${2:arg1, arg2, arg3})
+ $0
+end]]
+
+add('key', 'if', 'if .. then') [[
+if ${1:true} then
+ $0
+end]]
+
+add('key', 'repeat', 'repeat .. until') [[
+repeat
+ $0
+until ${1:true}]]
+
+add('key', 'while', 'while .. do') [[
+while ${1:true} do
+ $0
+end]]
+
+add('key', 'return', 'do return end')
+[[do return ${1:true} end]]
+
+return snippet