diff options
Diffstat (limited to 'script/core')
-rw-r--r-- | script/core/code_action.lua | 410 | ||||
-rw-r--r-- | script/core/completion.lua | 1079 | ||||
-rw-r--r-- | script/core/definition.lua | 296 | ||||
-rw-r--r-- | script/core/diagnostics.lua | 1042 | ||||
-rw-r--r-- | script/core/document_symbol.lua | 260 | ||||
-rw-r--r-- | script/core/find_lib.lua | 65 | ||||
-rw-r--r-- | script/core/find_source.lua | 59 | ||||
-rw-r--r-- | script/core/folding_range.lua | 73 | ||||
-rw-r--r-- | script/core/global.lua | 49 | ||||
-rw-r--r-- | script/core/highlight.lua | 54 | ||||
-rw-r--r-- | script/core/hover/emmy_function.lua | 143 | ||||
-rw-r--r-- | script/core/hover/function.lua | 243 | ||||
-rw-r--r-- | script/core/hover/hover.lua | 326 | ||||
-rw-r--r-- | script/core/hover/init.lua | 1 | ||||
-rw-r--r-- | script/core/hover/lib_function.lua | 222 | ||||
-rw-r--r-- | script/core/hover/name.lua | 38 | ||||
-rw-r--r-- | script/core/implementation.lua | 204 | ||||
-rw-r--r-- | script/core/init.lua | 19 | ||||
-rw-r--r-- | script/core/library.lua | 296 | ||||
-rw-r--r-- | script/core/matchKey.lua | 30 | ||||
-rw-r--r-- | script/core/name.lua | 70 | ||||
-rw-r--r-- | script/core/references.lua | 91 | ||||
-rw-r--r-- | script/core/rename.lua | 72 | ||||
-rw-r--r-- | script/core/signature.lua | 133 | ||||
-rw-r--r-- | script/core/snippet.lua | 64 |
25 files changed, 5339 insertions, 0 deletions
diff --git a/script/core/code_action.lua b/script/core/code_action.lua new file mode 100644 index 00000000..2c1fb14d --- /dev/null +++ b/script/core/code_action.lua @@ -0,0 +1,410 @@ +local lang = require 'language' +local library = require 'core.library' + +local function disableDiagnostic(lsp, uri, data, callback) + callback { + title = lang.script('ACTION_DISABLE_DIAG', data.code), + kind = 'quickfix', + command = { + title = lang.script.COMMAND_DISABLE_DIAG, + command = 'config', + arguments = { + { + key = {'diagnostics', 'disable'}, + action = 'add', + value = data.code, + } + } + } + } +end + +local function addGlobal(name, callback) + callback { + title = lang.script('ACTION_MARK_GLOBAL', name), + kind = 'quickfix', + command = { + title = lang.script.COMMAND_MARK_GLOBAL, + command = 'config', + arguments = { + { + key = {'diagnostics', 'globals'}, + action = 'add', + value = name, + } + } + }, + } +end + +local function changeVersion(version, callback) + callback { + title = lang.script('ACTION_RUNTIME_VERSION', version), + kind = 'quickfix', + command = { + title = lang.script.COMMAND_RUNTIME_VERSION, + command = 'config', + arguments = { + { + key = {'runtime', 'version'}, + action = 'set', + value = version, + } + } + }, + } +end + +local function openCustomLibrary(libName, callback) + callback { + title = lang.script('ACTION_OPEN_LIBRARY', libName), + kind = 'quickfix', + command = { + title = lang.script.COMMAND_OPEN_LIBRARY, + command = 'config', + arguments = { + { + key = {'runtime', 'library'}, + action = 'add', + value = libName, + } + } + }, + } +end + +local function solveUndefinedGlobal(lsp, uri, data, callback) + local vm, lines, text = lsp:getVM(uri) + if not vm then + return + end + local start = lines:position(data.range.start.line + 1, data.range.start.character + 1) + local finish = lines:position(data.range['end'].line + 1, data.range['end'].character) + local name = text:sub(start, finish) + if #name < 0 or name:find('[^%w_]') then + return + end + addGlobal(name, callback) + local otherVersion = library.other[name] + if otherVersion then + for _, version in ipairs(otherVersion) do + changeVersion(version, callback) + end + end + + local customLibrary = library.custom[name] + if customLibrary then + for _, libName in ipairs(customLibrary) do + openCustomLibrary(libName, callback) + end + end +end + +local function solveLowercaseGlobal(lsp, uri, data, callback) + local vm, lines, text = lsp:getVM(uri) + if not vm then + return + end + local start = lines:position(data.range.start.line + 1, data.range.start.character + 1) + local finish = lines:position(data.range['end'].line + 1, data.range['end'].character) + local name = text:sub(start, finish) + if #name < 0 or name:find('[^%w_]') then + return + end + addGlobal(name, callback) +end + +local function solveTrailingSpace(lsp, uri, data, callback) + callback { + title = lang.script.ACTION_REMOVE_SPACE, + kind = 'quickfix', + command = { + title = lang.script.COMMAND_REMOVE_SPACE, + command = 'removeSpace', + arguments = { + { + uri = uri, + } + } + }, + } +end + +local function solveNewlineCall(lsp, uri, data, callback) + callback { + title = lang.script.ACTION_ADD_SEMICOLON, + kind = 'quickfix', + edit = { + changes = { + [uri] = { + { + range = { + start = data.range.start, + ['end'] = data.range.start, + }, + newText = ';', + } + } + } + } + } +end + +local function solveAmbiguity1(lsp, uri, data, callback) + callback { + title = lang.script.ACTION_ADD_BRACKETS, + kind = 'quickfix', + command = { + title = lang.script.COMMAND_ADD_BRACKETS, + command = 'solve', + arguments = { + { + name = 'ambiguity-1', + uri = uri, + range = data.range, + } + } + }, + } +end + +local function findSyntax(astErr, lines, data) + local start = lines:position(data.range.start.line + 1, data.range.start.character + 1) + local finish = lines:position(data.range['end'].line + 1, data.range['end'].character) + for _, err in ipairs(astErr) do + if err.start == start and err.finish == finish then + return err + end + end + return nil +end + +local function solveSyntaxByChangeVersion(err, callback) + if type(err.version) == 'table' then + for _, version in ipairs(err.version) do + changeVersion(version, callback) + end + else + changeVersion(err.version, callback) + end +end + +local function solveSyntaxByAddDoEnd(uri, data, callback) + callback { + title = lang.script.ACTION_ADD_DO_END, + kind = 'quickfix', + edit = { + changes = { + [uri] = { + { + range = { + start = data.range.start, + ['end'] = data.range.start, + }, + newText = 'do ', + }, + { + range = { + start = data.range['end'], + ['end'] = data.range['end'], + }, + newText = ' end', + } + } + } + } + } +end + +local function solveSyntaxByFix(uri, err, lines, callback) + local changes = {} + for _, e in ipairs(err.fix) do + local start_row, start_col = lines:rowcol(e.start) + local finish_row, finish_col = lines:rowcol(e.finish) + changes[#changes+1] = { + range = { + start = { + line = start_row - 1, + character = start_col - 1, + }, + ['end'] = { + line = finish_row - 1, + character = finish_col, + }, + }, + newText = e.text, + } + end + callback { + title = lang.script['ACTION_' .. err.fix.title], + kind = 'quickfix', + edit = { + changes = { + [uri] = changes, + } + } + } +end + +local function findEndPosition(lines, row, endrow) + if endrow == row then + return { + newText = ' end', + range = { + start = { + line = row - 1, + character = 999999, + }, + ['end'] = { + line = row - 1, + character = 999999, + } + } + } + else + local l = lines[row] + return { + newText = ('\t'):rep(l.tab) .. (' '):rep(l.sp) .. 'end\n', + range = { + start = { + line = endrow, + character = 0, + }, + ['end'] = { + line = endrow, + character = 0, + } + } + } + end +end + +local function isIfPart(id, lines, i) + if id ~= 'if' then + return false + end + local buf = lines:line(i) + local first = buf:match '^[%s\t]*([%w]+)' + if first == 'else' or first == 'elseif' then + return true + end + return false +end + +local function solveSyntaxByAddEnd(uri, start, finish, lines, callback) + local row = lines:rowcol(start) + local line = lines[row] + if not line then + return nil + end + local id = lines.buf:sub(start, finish) + local sp = line.sp + line.tab * 4 + for i = row + 1, #lines do + local nl = lines[i] + local lsp = nl.sp + nl.tab * 4 + if lsp <= sp and not isIfPart(id, lines, i) then + callback { + title = lang.script['ACTION_ADD_END'], + kind = 'quickfix', + edit = { + changes = { + [uri] = { + findEndPosition(lines, row, i - 1) + } + } + } + } + return + end + end + return nil +end + +---@param lsp LSP +---@param uri uri +---@param data table +---@param callback function +local function solveSyntax(lsp, uri, data, callback) + local file = lsp:getFile(uri) + if not file then + return + end + local astErr, lines = file:getAstErr(), file:getLines() + if not astErr or not lines then + return + end + local err = findSyntax(astErr, lines, data) + if not err then + return nil + end + if err.version then + solveSyntaxByChangeVersion(err, callback) + end + if err.type == 'ACTION_AFTER_BREAK' or err.type == 'ACTION_AFTER_RETURN' then + solveSyntaxByAddDoEnd(uri, data, callback) + end + if err.type == 'MISS_END' then + solveSyntaxByAddEnd(uri, err.start, err.finish, lines, callback) + end + if err.type == 'MISS_SYMBOL' and err.info.symbol == 'end' then + solveSyntaxByAddEnd(uri, err.info.related[1], err.info.related[2], lines, callback) + end + if err.fix then + solveSyntaxByFix(uri, err, lines, callback) + end +end + +local function solveDiagnostic(lsp, uri, data, callback) + if data.source == lang.script.DIAG_SYNTAX_CHECK then + solveSyntax(lsp, uri, data, callback) + end + if not data.code then + return + end + if data.code == 'undefined-global' then + solveUndefinedGlobal(lsp, uri, data, callback) + end + if data.code == 'trailing-space' then + solveTrailingSpace(lsp, uri, data, callback) + end + if data.code == 'newline-call' then + solveNewlineCall(lsp, uri, data, callback) + end + if data.code == 'ambiguity-1' then + solveAmbiguity1(lsp, uri, data, callback) + end + if data.code == 'lowercase-global' then + solveLowercaseGlobal(lsp, uri, data, callback) + end + disableDiagnostic(lsp, uri, data, callback) +end + +local function rangeContain(a, b) + if a.start.line > b.start.line then + return false + end + if a.start.character > b.start.character then + return false + end + if a['end'].line < b['end'].line then + return false + end + if a['end'].character < b['end'].character then + return false + end + return true +end + +return function (lsp, uri, diagnostics, range) + local results = {} + + for _, data in ipairs(diagnostics) do + if rangeContain(data.range, range) then + solveDiagnostic(lsp, uri, data, function (result) + results[#results+1] = result + end) + end + end + + return results +end diff --git a/script/core/completion.lua b/script/core/completion.lua new file mode 100644 index 00000000..756f136b --- /dev/null +++ b/script/core/completion.lua @@ -0,0 +1,1079 @@ +local findSource = require 'core.find_source' +local getFunctionHover = require 'core.hover.function' +local getFunctionHoverAsLib = require 'core.hover.lib_function' +local getFunctionHoverAsEmmy = require 'core.hover.emmy_function' +local sourceMgr = require 'vm.source' +local config = require 'config' +local matchKey = require 'core.matchKey' +local parser = require 'parser' +local lang = require 'language' +local snippet = require 'core.snippet' +local State + +local CompletionItemKind = { + Text = 1, + Method = 2, + Function = 3, + Constructor = 4, + Field = 5, + Variable = 6, + Class = 7, + Interface = 8, + Module = 9, + Property = 10, + Unit = 11, + Value = 12, + Enum = 13, + Keyword = 14, + Snippet = 15, + Color = 16, + File = 17, + Reference = 18, + Folder = 19, + EnumMember = 20, + Constant = 21, + Struct = 22, + Event = 23, + Operator = 24, + TypeParameter = 25, +} + +local KEYS = {'and', 'break', 'do', 'else', 'elseif', 'end', 'false', 'for', 'function', 'goto', 'if', 'in', 'local', 'nil', 'not', 'or', 'repeat', 'return', 'then', 'true', 'until', 'while'} +local KEYMAP = {} +for _, k in ipairs(KEYS) do + KEYMAP[k] = true +end + +local EMMY_KEYWORD = {'class', 'type', 'alias', 'param', 'return', 'field', 'generic', 'vararg', 'language', 'see', 'overload'} + +local function getDucumentation(name, value) + if value:getType() == 'function' then + local lib = value:getLib() + local hover + if lib then + hover = getFunctionHoverAsLib(name, lib) + else + local emmy = value:getEmmy() + if emmy and emmy.type == 'emmy.functionType' then + hover = getFunctionHoverAsEmmy(name, emmy) + else + hover = getFunctionHover(name, value:getFunction()) + end + end + if not hover then + return nil + end + local text = ([[ +```lua +%s +``` +%s +```lua +%s +``` +%s +]]):format(hover.label or '', hover.description or '', hover.enum or '', hover.doc or '') + return { + kind = 'markdown', + value = text, + } + end + local lib = value:getLib() + if lib then + return { + kind = 'markdown', + value = lib.description, + } + end + local comment = value:getComment() + if comment then + return { + kind = 'markdown', + value = comment, + } + end + return nil +end + +local function getDetail(value) + local literal = value:getLiteral() + local tp = type(literal) + local detals = {} + if value:getType() ~= 'any' then + detals[#detals+1] = ('(%s)'):format(value:getType()) + end + if tp == 'boolean' then + detals[#detals+1] = (' = %q'):format(literal) + elseif tp == 'string' then + detals[#detals+1] = (' = %q'):format(literal) + elseif tp == 'number' then + if math.type(literal) == 'integer' then + detals[#detals+1] = (' = %q'):format(literal) + else + local str = (' = %.16f'):format(literal) + local dot = str:find('.', 1, true) + local suffix = str:find('[0]+$', dot + 2) + if suffix then + detals[#detals+1] = str:sub(1, suffix - 1) + else + detals[#detals+1] = str + end + end + end + if value:getType() == 'function' then + ---@type emmyFunction + local func = value:getFunction() + local overLoads = func and func:getEmmyOverLoads() + if overLoads then + detals[#detals+1] = lang.script('HOVER_MULTI_PROTOTYPE', #overLoads + 1) + end + end + if #detals == 0 then + return nil + end + return table.concat(detals) +end + +local function getKind(cata, value) + if value:getType() == 'function' then + local func = value:getFunction() + if func and func:getObject() then + return CompletionItemKind.Method + else + return CompletionItemKind.Function + end + end + if cata == 'field' then + local literal = value:getLiteral() + local tp = type(literal) + if tp == 'number' or tp == 'integer' or tp == 'string' then + return CompletionItemKind.Enum + end + end + return nil +end + +local function buildSnipArgs(args, enums) + local t = {} + for i, arg in ipairs(args) do + local name = arg:match '^[^:]+' + local enum = enums and enums[name] + if enum and #enum > 0 then + t[i] = ('${%d|%s|}'):format(i, table.concat(enum, ',')) + else + t[i] = ('${%d:%s}'):format(i, arg) + end + end + return table.concat(t, ', ') +end + +local function getFunctionSnip(name, value, source) + if value:getType() ~= 'function' then + return + end + local lib = value:getLib() + local object = source:get 'object' + local hover + if lib then + hover = getFunctionHoverAsLib(name, lib, object) + else + local emmy = value:getEmmy() + if emmy and emmy.type == 'emmy.functionType' then + hover = getFunctionHoverAsEmmy(name, emmy, object) + else + hover = getFunctionHover(name, value:getFunction(), object) + end + end + if not hover then + return ('%s()'):format(name) + end + if not hover.args then + return ('%s()'):format(name) + end + return ('%s(%s)'):format(name, buildSnipArgs(hover.args, hover.rawEnum)) +end + +local function getValueData(cata, name, value, pos, source) + local data = { + documentation = getDucumentation(name, value), + detail = getDetail(value), + kind = getKind(cata, value), + snip = getFunctionSnip(name, value, source), + } + if cata == 'field' then + if not parser:grammar(name, 'Name') then + if source:get 'simple' and source:get 'simple' [1] ~= source then + data.textEdit = { + start = pos + 1, + finish = pos, + newText = ('[%q]'):format(name), + } + data.additionalTextEdits = { + { + start = pos, + finish = pos, + newText = '', + } + } + else + data.textEdit = { + start = pos + 1, + finish = pos, + newText = ('_ENV[%q]'):format(name), + } + data.additionalTextEdits = { + { + start = pos, + finish = pos, + newText = '', + } + } + end + end + end + return data +end + +local function searchLocals(vm, source, word, callback, pos) + vm:eachSource(function (src) + local loc = src:bindLocal() + if not loc then + return + end + + if src.start <= source.start + and loc:close() >= source.finish + and matchKey(word, loc:getName()) + then + callback(loc:getName(), src, CompletionItemKind.Variable, getValueData('local', loc:getName(), loc:getValue(), pos, source)) + end + end) +end + +local function sortPairs(t) + local keys = {} + for k in pairs(t) do + keys[#keys+1] = k + end + table.sort(keys) + local i = 0 + return function () + i = i + 1 + local k = keys[i] + return k, t[k] + end +end + +local function searchFieldsByInfo(parent, word, source, map) + parent:eachInfo(function (info, src) + local k = info[1] + if src == source then + return + end + if map[k] then + return + end + if KEYMAP[k] then + return + end + if info.type ~= 'set child' and info.type ~= 'get child' then + return + end + if type(k) ~= 'string' then + return + end + local v = parent:getChild(k) + if not v then + return + end + if source:get 'object' and v:getType() ~= 'function' then + return + end + if matchKey(word, k) then + map[k] = v + end + end) +end + +local function searchFieldsByChild(parent, word, source, map) + parent:eachChild(function (k, v) + if map[k] then + return + end + if KEYMAP[k] then + return + end + if not v:getLib() then + return + end + if type(k) ~= 'string' then + return + end + if source:get 'object' and v:getType() ~= 'function' then + return + end + if matchKey(word, k) then + map[k] = v + end + end) +end + +---@param vm VM +local function searchFields(vm, source, word, callback, pos) + local parent = source:get 'parent' or vm.env:getValue() + if not parent then + return + end + local map = {} + local current = parent + for _ = 1, 3 do + searchFieldsByInfo(current, word, source, map) + current = current:getMetaMethod('__index') + if not current then + break + end + end + searchFieldsByChild(parent, word, source, map) + for k, v in sortPairs(map) do + callback(k, nil, CompletionItemKind.Field, getValueData('field', k, v, pos, source)) + end +end + +local function searchIndex(vm, source, word, callback) + vm:eachSource(function (src) + if src:get 'table index' then + if matchKey(word, src[1]) then + callback(src[1], src, CompletionItemKind.Property) + end + end + end) +end + +local function searchCloseGlobal(vm, start, finish, word, callback) + vm:eachSource(function (src) + if (src:get 'global' or src:bindLocal()) + and src.start >= start + and src.finish <= finish + then + if matchKey(word, src[1]) then + callback(src[1], src, CompletionItemKind.Variable) + end + end + end) +end + +local function searchParams(vm, source, func, word, callback) + if not func then + return + end + ---@type emmyFunction + local emmyParams = func:getEmmyParams() + if not emmyParams then + return + end + if #emmyParams > 1 then + if not func.args + or not func.args[1] + or func.args[1]:getSource() == source then + if matchKey(word, source and source[1] or '') then + local names = {} + for _, param in ipairs(emmyParams) do + local name = param:getName() + names[#names+1] = name + end + callback(table.concat(names, ', '), nil, CompletionItemKind.Snippet) + end + end + end + for _, param in ipairs(emmyParams) do + local name = param:getName() + if matchKey(word, name) then + callback(name, param:getSource(), CompletionItemKind.Interface) + end + end +end + +local function searchKeyWords(vm, source, word, callback) + local snipType = config.config.completion.keywordSnippet + for _, key in ipairs(KEYS) do + if matchKey(word, key) then + if snippet.key[key] then + if snipType ~= 'Replace' + or key == 'local' + or key == 'return' then + callback(key, nil, CompletionItemKind.Keyword) + end + if snipType ~= 'Disable' then + for _, data in ipairs(snippet.key[key]) do + callback(data.label, nil, CompletionItemKind.Snippet, { + insertText = data.text, + }) + end + end + else + callback(key, nil, CompletionItemKind.Keyword) + end + end + end +end + +local function searchGlobals(vm, source, word, callback, pos) + local global = vm.env:getValue() + local map = {} + local current = global + for _ = 1, 3 do + searchFieldsByInfo(current, word, source, map) + current = current:getMetaMethod('__index') + if not current then + break + end + end + searchFieldsByChild(global, word, source, map) + for k, v in sortPairs(map) do + callback(k, nil, CompletionItemKind.Field, getValueData('field', k, v, pos, source)) + end +end + +local function searchAsGlobal(vm, source, word, callback, pos) + if word == '' then + return + end + searchLocals(vm, source, word, callback, pos) + searchFields(vm, source, word, callback, pos) + searchKeyWords(vm, source, word, callback) +end + +local function searchAsKeyowrd(vm, source, word, callback, pos) + searchLocals(vm, source, word, callback, pos) + searchGlobals(vm, source, word, callback, pos) + searchKeyWords(vm, source, word, callback) +end + +local function searchAsSuffix(vm, source, word, callback, pos) + searchFields(vm, source, word, callback, pos) +end + +local function searchAsIndex(vm, source, word, callback, pos) + searchLocals(vm, source, word, callback, pos) + searchIndex(vm, source, word, callback) + searchFields(vm, source, word, callback, pos) +end + +local function searchAsLocal(vm, source, word, callback) + local loc = source:bindLocal() + if not loc then + return + end + local close = loc:close() + -- 因为闭包的关系落在局部变量finish到close范围内的全局变量一定能访问到该局部变量 + searchCloseGlobal(vm, source.finish, close, word, callback) + -- 特殊支持 local function + if matchKey(word, 'function') then + callback('function', nil, CompletionItemKind.Keyword) + -- TODO 需要有更优美的实现方式 + local data = snippet.key['function'][1] + callback(data.label, nil, CompletionItemKind.Snippet, { + insertText = data.text, + }) + end +end + +local function searchAsArg(vm, source, word, callback) + searchParams(vm, source, source:get 'arg', word, callback) + + local loc = source:bindLocal() + if loc then + local close = loc:close() + -- 因为闭包的关系落在局部变量finish到close范围内的全局变量一定能访问到该局部变量 + searchCloseGlobal(vm, source.finish, close, word, callback) + return + end +end + +local function searchFunction(vm, source, word, pos, callback) + if pos >= source.argStart and pos <= source.argFinish then + searchParams(vm, nil, source:bindFunction():getFunction(), word, callback) + searchCloseGlobal(vm, source.argFinish, source.finish, word, callback) + end +end + +local function searchEmmyKeyword(vm, source, word, callback) + for _, kw in ipairs(EMMY_KEYWORD) do + if matchKey(word, kw) then + callback(kw, nil, CompletionItemKind.Keyword) + end + end +end + +local function searchEmmyClass(vm, source, word, callback) + local classes = {} + vm.emmyMgr:eachClass(function (class) + if class.type == 'emmy.class' or class.type == 'emmy.alias' then + if matchKey(word, class:getName()) then + classes[#classes+1] = class + end + end + end) + table.sort(classes, function (a, b) + return a:getName() < b:getName() + end) + for _, class in ipairs(classes) do + callback(class:getName(), class:getSource(), CompletionItemKind.Class) + end +end + +local function searchEmmyFunctionParam(vm, source, word, callback) + local func = source:get 'emmy function' + if not func.args then + return + end + if #func.args > 1 and matchKey(word, func.args[1].name) then + local list = {} + local args = {} + for i, arg in ipairs(func.args) do + if func:getObject() and i == 1 then + goto NEXT + end + args[#args+1] = arg.name + if #list == 0 then + list[#list+1] = ('%s any'):format(arg.name) + else + list[#list+1] = ('---@param %s any'):format(arg.name) + end + :: NEXT :: + end + callback(('%s'):format(table.concat(args, ', ')), nil, CompletionItemKind.Snippet, { + insertText = table.concat(list, '\n') + }) + end + for i, arg in ipairs(func.args) do + if func:getObject() and i == 1 then + goto NEXT + end + if matchKey(word, arg.name) then + callback(arg.name, nil, CompletionItemKind.Interface) + end + :: NEXT :: + end +end + +local function searchSource(vm, source, word, callback, pos) + if source.type == 'keyword' then + searchAsKeyowrd(vm, source, word, callback, pos) + return + end + if source:get 'table index' then + searchAsIndex(vm, source, word, callback, pos) + return + end + if source:get 'arg' then + searchAsArg(vm, source, word, callback) + return + end + if source:get 'global' then + searchAsGlobal(vm, source, word, callback, pos) + return + end + if source:action() == 'local' then + searchAsLocal(vm, source, word, callback) + return + end + if source:bindLocal() then + searchAsGlobal(vm, source, word, callback, pos) + return + end + if source:get 'simple' + and (source.type == 'name' or source.type == '.' or source.type == ':') then + searchAsSuffix(vm, source, word, callback, pos) + return + end + if source:bindFunction() then + searchFunction(vm, source, word, pos, callback) + return + end + if source.type == 'emmyIncomplete' then + searchEmmyKeyword(vm, source, word, callback) + State.ignoreText = true + return + end + if source:get 'emmy class' then + searchEmmyClass(vm, source, word, callback) + State.ignoreText = true + return + end + if source:get 'emmy function' then + searchEmmyFunctionParam(vm, source, word, callback) + State.ignoreText = true + return + end +end + +local function buildTextEdit(start, finish, str, quo) + local text, lquo, rquo, label, filterText + if quo == nil then + local text = str:gsub('\r', '\\r'):gsub('\n', '\\n'):gsub('"', '\\"') + return { + label = '"' .. text .. '"' + } + end + if quo == '"' then + label = str + filterText = str + text = str:gsub('\r', '\\r'):gsub('\n', '\\n'):gsub('"', '\\"') + lquo = quo + rquo = quo + elseif quo == "'" then + label = str + filterText = str + text = str:gsub('\r', '\\r'):gsub('\n', '\\n'):gsub("'", "\\'") + lquo = quo + rquo = quo + else + label = str + filterText = str + lquo = quo + rquo = ']' .. lquo:sub(2, -2) .. ']' + while str:find(rquo, 1, true) do + lquo = '[=' .. quo:sub(2) + rquo = ']' .. lquo:sub(2, -2) .. ']' + end + text = str + end + return { + label = label, + filterText = filterText, + textEdit = { + start = start + #quo, + finish = finish - #quo, + newText = text, + }, + additionalTextEdits = { + { + start = start, + finish = start + #quo - 1, + newText = lquo, + }, + { + start = finish - #quo + 1, + finish = finish, + newText = rquo, + }, + } + } +end + +local function searchInRequire(vm, source, callback) + if not vm.lsp or not vm.lsp.workspace then + return + end + if source.type ~= 'string' then + return + end + local list, map = vm.lsp.workspace:matchPath(vm.uri, source[1]) + if not list then + return + end + for _, str in ipairs(list) do + local data = buildTextEdit(source.start, source.finish, str, source[2]) + data.documentation = map[str] + callback(str, nil, CompletionItemKind.Reference, data) + end +end + +local function searchEnumAsLib(vm, source, word, callback, pos, args, lib) + local select = #args + 1 + for i, arg in ipairs(args) do + if arg.start <= pos and arg.finish >= pos then + select = i + break + end + end + + -- 根据参数位置找枚举值 + if lib.args and lib.enums then + local arg = lib.args[select] + local name = arg and arg.name + for _, enum in ipairs(lib.enums) do + if enum.name and enum.name == name and enum.enum then + if matchKey(word, enum.enum) then + local strSource = parser:parse(tostring(enum.enum), 'String') + if strSource then + if source.type == 'string' then + local data = buildTextEdit(source.start, source.finish, strSource[1], source[2]) + data.documentation = enum.description + callback(enum.enum, nil, CompletionItemKind.EnumMember, data) + else + callback(enum.enum, nil, CompletionItemKind.EnumMember, { + documentation = enum.description + }) + end + end + else + callback(enum.enum, nil, CompletionItemKind.EnumMember, { + documentation = enum.description + }) + end + end + end + end + + -- 搜索特殊函数 + if lib.special == 'require' then + if select == 1 then + searchInRequire(vm, source, callback) + end + end +end + +local function buildEmmyEnumComment(enum, data) + if not enum.comment then + return data + end + if not data then + data = {} + end + data.documentation = tostring(enum.comment) + return data +end + +local function searchEnumAsEmmyParams(vm, source, word, callback, pos, args, func) + local select = #args + 1 + for i, arg in ipairs(args) do + if arg.start <= pos and arg.finish >= pos then + select = i + break + end + end + + local param = func:findEmmyParamByIndex(select) + if not param then + return + end + + param:eachEnum(function (enum) + local str = enum[1] + if matchKey(word, str) then + local strSource = parser:parse(tostring(str), 'String') + if strSource then + if source.type == 'string' then + local data = buildTextEdit(source.start, source.finish, strSource[1], source[2]) + callback(str, nil, CompletionItemKind.EnumMember, buildEmmyEnumComment(enum, data)) + else + callback(str, nil, CompletionItemKind.EnumMember, buildEmmyEnumComment(enum)) + end + else + callback(str, nil, CompletionItemKind.EnumMember, buildEmmyEnumComment(enum)) + end + end + end) + + local option = param:getOption() + if option and option.special == 'require:1' then + searchInRequire(vm, source, callback) + end +end + +local function getSelect(args, pos) + if not args then + return 1 + end + for i, arg in ipairs(args) do + if arg.start <= pos and arg.finish >= pos - 1 then + return i + end + end + return #args + 1 +end + +local function isInFunctionOrTable(call, pos) + local args = call:bindCall() + if not args then + return false + end + local select = getSelect(args, pos) + local arg = args[select] + if not arg then + return false + end + if arg.type == 'function' or arg.type == 'table' then + return true + end + return false +end + +local function searchCallArg(vm, source, word, callback, pos) + local results = {} + vm:eachSource(function (src) + if src.type == 'call' + and src.start <= pos + and src.finish >= pos + then + results[#results+1] = src + end + end) + if #results == 0 then + return nil + end + -- 可能处于 'func1(func2(' 的嵌套中,将最近的call放到最前面 + table.sort(results, function (a, b) + return a.start > b.start + end) + local call = results[1] + if isInFunctionOrTable(call, pos) then + return + end + + local args = call:bindCall() + if not args then + return + end + + local value = call:findCallFunction() + if not value then + return + end + + local lib = value:getLib() + if lib then + searchEnumAsLib(vm, source, word, callback, pos, args, lib) + return + end + + ---@type emmyFunction + local func = value:getFunction() + if func then + searchEnumAsEmmyParams(vm, source, word, callback, pos, args, func) + return + end +end + +local function searchAllWords(vm, source, word, callback, pos) + if word == '' then + return + end + if source.type == 'string' then + return + end + vm:eachSource(function (src) + if src.type == 'name' + and not (src.start <= pos and src.finish >= pos) + and matchKey(word, src[1]) + then + callback(src[1], src, CompletionItemKind.Text) + end + end) +end + +local function searchSpecialHashSign(vm, pos, text, callback) + -- 尝试 XXX[#XXX+1] + -- 1. 搜索 [] + local index + vm:eachSource(function (src) + if src.type == 'index' + and src.start <= pos + and src.finish >= pos + then + index = src + return true + end + end) + if not index then + return nil + end + -- 2. [] 内部只能有一个 # + local inside = index[1] + if not inside then + return nil + end + if inside.op ~= '#' then + return nil + end + -- 3. [] 左侧必须是 simple ,且index 是 simple 的最后一项 + local simple = index:get 'simple' + if not simple then + return nil + end + if simple[#simple] ~= index then + return nil + end + local chars = text:sub(simple.start, simple[#simple-1].finish) + -- 4. 创建代码片段 + if simple:get 'as action' then + local label = chars .. '+1' + callback(label, nil, CompletionItemKind.Snippet, { + textEdit = { + start = inside.start + 1, + finish = index.finish, + newText = ('%s] = '):format(label), + }, + }) + else + local label = chars + callback(label, nil, CompletionItemKind.Snippet, { + textEdit = { + start = inside.start + 1, + finish = index.finish, + newText = ('%s]'):format(label), + }, + }) + end +end + +local function searchSpecial(vm, source, word, callback, pos, text) + searchSpecialHashSign(vm, pos, text, callback) +end + +local function makeList(source, pos, word) + local list = {} + local mark = {} + return function (name, src, kind, data) + if src == source then + return + end + if word == name then + if src and src.start <= pos and src.finish >= pos then + return + end + end + if mark[name] then + return + end + mark[name] = true + if not data then + data = {} + end + if not data.label then + data.label = name + end + if not data.kind then + data.kind = kind + end + list[#list+1] = data + if data.snip then + local snipType = config.config.completion.callSnippet + if snipType ~= 'Disable' then + local snipData = table.deepCopy(data) + snipData.insertText = data.snip + snipData.kind = CompletionItemKind.Snippet + snipData.label = snipData.label .. '()' + snipData.snip = nil + if snipType == 'Both' then + list[#list+1] = snipData + elseif snipType == 'Replace' then + list[#list] = snipData + end + end + data.snip = nil + end + end, list +end + +local function keywordSource(vm, word, pos) + if not KEYMAP[word] then + return nil + end + return vm:instantSource { + type = 'keyword', + start = pos, + finish = pos + #word - 1, + [1] = word, + } +end + +local function findStartPos(pos, buf) + local res = nil + for i = pos, 1, -1 do + local c = buf:sub(i, i) + if c:find '[%w_]' then + res = i + else + break + end + end + if not res then + for i = pos, 1, -1 do + local c = buf:sub(i, i) + if c == '.' or c == ':' or c == '|' or c == '(' then + res = i + break + elseif c == '#' or c == '@' then + res = i + 1 + break + elseif c:find '[%s%c]' then + else + break + end + end + end + if not res then + return pos + end + return res +end + +local function findWord(position, text) + local word = text + for i = position, 1, -1 do + local c = text:sub(i, i) + if not c:find '[%w_]' then + word = text:sub(i+1, position) + break + end + end + return word:match('^([%w_]*)') +end + +local function getSource(vm, pos, text, filter) + local word = findWord(pos, text) + local source = keywordSource(vm, word, pos) + if source then + return source, pos, word + end + source = findSource(vm, pos, filter) + if source then + return source, pos, word + end + pos = findStartPos(pos, text) + source = findSource(vm, pos, filter) or findSource(vm, pos-1, filter) + return source, pos, word +end + +return function (vm, text, pos, oldText) + local filter = { + ['name'] = true, + ['string'] = true, + ['.'] = true, + [':'] = true, + ['emmyName'] = true, + ['emmyIncomplete'] = true, + ['call'] = true, + ['function'] = true, + ['localfunction'] = true, + } + local source, pos, word = getSource(vm, pos, text, filter) + if not source then + return nil + end + if oldText then + local oldWord = oldText:sub(source.start, source.finish) + if word:sub(1, #oldWord) ~= oldWord then + return nil + end + end + State = {} + local callback, list = makeList(source, pos, word) + searchSpecial(vm, source, word, callback, pos, text) + searchCallArg(vm, source, word, callback, pos) + searchSource(vm, source, word, callback, pos) + if not oldText or #list > 0 then + if not State.ignoreText then + searchAllWords(vm, source, word, callback, pos) + end + end + + if #list == 0 then + return nil + end + + return list +end diff --git a/script/core/definition.lua b/script/core/definition.lua new file mode 100644 index 00000000..8680a29b --- /dev/null +++ b/script/core/definition.lua @@ -0,0 +1,296 @@ +local findSource = require 'core.find_source' +local Mode + +local function parseValueSimily(callback, vm, source) + local key = source[1] + if not key then + return nil + end + vm:eachSource(function (other) + if other == source then + goto CONTINUE + end + if other[1] == key + and not other:bindLocal() + and other:bindValue() + and source:bindValue() ~= other:bindValue() + then + if Mode == 'definition' then + if other:action() == 'set' then + callback(other) + end + elseif Mode == 'reference' then + if other:action() == 'set' or other:action() == 'get' then + callback(other) + end + end + end + :: CONTINUE :: + end) +end + +local function parseLocal(callback, vm, source) + ---@type Local + local loc = source:bindLocal() + callback(loc:getSource()) + loc:eachInfo(function (info, src) + if Mode == 'definition' then + if info.type == 'set' or info.type == 'local' then + if vm.uri == src:getUri() then + if source.id >= src.id then + callback(src) + end + end + end + elseif Mode == 'reference' then + if info.type == 'set' or info.type == 'local' or info.type == 'return' or info.type == 'get' then + callback(src) + end + end + end) +end + +local function parseValueByValue(callback, vm, source, value) + if not source then + return + end + local mark = { [vm] = true } + local list = {} + for _ = 1, 5 do + value:eachInfo(function (info, src) + if Mode == 'definition' then + if info.type == 'local' then + if vm.uri == src:getUri() then + if source.id >= src.id then + callback(src) + end + end + end + if info.type == 'set' then + if vm.uri == src:getUri() then + if source.id >= src.id then + callback(src) + end + else + callback(src) + end + end + if info.type == 'return' then + if (src.type ~= 'simple' or src[#src].type == 'call') + and src.type ~= 'name' + then + callback(src) + end + if vm.lsp then + local destVM = vm.lsp:getVM(src:getUri()) + if destVM and not mark[destVM] then + mark[destVM] = true + list[#list+1] = { destVM, src } + end + end + end + elseif Mode == 'reference' then + if info.type == 'set' or info.type == 'local' or info.type == 'return' or info.type == 'get' then + callback(src) + end + end + end) + local nextData = table.remove(list, 1) + if nextData then + vm, source = nextData[1], nextData[2] + end + end +end + +local function parseValue(callback, vm, source) + local value = source:bindValue() + local isGlobal + if value then + isGlobal = value:isGlobal() + parseValueByValue(callback, vm, source, value) + local emmy = value:getEmmy() + if emmy and emmy.type == 'emmy.type' then + ---@type EmmyType + local emmyType = emmy + emmyType:eachClass(function (class) + if class and class:getValue() then + local emmyVM = vm + if vm.lsp then + local destVM = vm.lsp:getVM(class:getSource():getUri()) + if destVM then + emmyVM = destVM + end + end + parseValueByValue(callback, emmyVM, class:getValue():getSource(), class:getValue()) + end + end) + end + end + local parent = source:get 'parent' + for _ = 1, 3 do + if parent then + local ok = parent:eachInfo(function (info, src) + if Mode == 'definition' then + if info.type == 'set child' and info[1] == source[1] then + callback(src) + return true + end + elseif Mode == 'reference' then + if (info.type == 'set child' or info.type == 'get child') and info[1] == source[1] then + callback(src) + return true + end + end + end) + if ok then + break + end + parent = parent:getMetaMethod('__index') + end + end + return isGlobal +end + +local function parseLabel(callback, vm, label) + label:eachInfo(function (info, src) + if Mode == 'definition' then + if info.type == 'set' then + callback(src) + end + elseif Mode == 'reference' then + if info.type == 'set' or info.type == 'get' then + callback(src) + end + end + end) +end + +local function jumpUri(callback, vm, source) + local uri = source:get 'target uri' + callback { + start = 0, + finish = 0, + uri = uri + } +end + +local function parseClass(callback, vm, source) + local className = source:get 'emmy class' + vm.emmyMgr:eachClass(className, function (class) + if Mode == 'definition' then + if class.type == 'emmy.class' or class.type == 'emmy.alias' then + local src = class:getSource() + callback(src) + end + elseif Mode == 'reference' then + if class.type == 'emmy.class' or class.type == 'emmy.alias' or class.type == 'emmy.typeUnit' then + local src = class:getSource() + callback(src) + end + end + end) +end + +local function parseSee(callback, vm, source) + local see = source:get 'emmy see' + local className = see[1][1] + local childName = see[2][1] + vm.emmyMgr:eachClass(className, function (class) + ---@type value + local value = class:getValue() + local child = value:getChild(childName) + parseValueByValue(callback, vm, source, child) + end) +end + +local function parseFunction(callback, vm, source) + if Mode == 'definition' then + callback(source:bindFunction():getSource()) + source:bindFunction():eachInfo(function (info, src) + if info.type == 'set' or info.type == 'local' then + if vm.uri == src:getUri() then + if source.id >= src.id then + callback(src) + end + else + callback(src) + end + end + end) + elseif Mode == 'reference' then + callback(source:bindFunction():getSource()) + source:bindFunction():eachInfo(function (info, src) + if info.type == 'set' or info.type == 'local' or info.type == 'get' then + callback(src) + end + end) + end +end + +local function makeList(source) + local list = {} + local mark = {} + return list, function (src) + if mark[src] then + return + end + mark[src] = true + list[#list+1] = { + src.start, + src.finish, + src.uri + } + end +end + +return function (vm, pos, mode) + local filter = { + ['name'] = true, + ['string'] = true, + ['number'] = true, + ['boolean'] = true, + ['label'] = true, + ['goto'] = true, + ['function'] = true, + ['...'] = true, + ['emmyName'] = true, + ['emmyIncomplete'] = true, + } + local source = findSource(vm, pos, filter) + if not source then + return nil + end + Mode = mode + local list, callback = makeList(source) + local isGlobal + if source:bindLocal() then + parseLocal(callback, vm, source) + end + if source:bindValue() then + isGlobal = parseValue(callback, vm, source) + end + if source:bindLabel() then + parseLabel(callback, vm, source:bindLabel()) + end + if source:bindFunction() then + parseFunction(callback, vm, source) + end + if source:get 'target uri' then + jumpUri(callback, vm, source) + end + if source:get 'in index' then + isGlobal = parseValue(callback, vm, source) + end + if source:get 'emmy class' then + parseClass(callback, vm, source) + end + if source:get 'emmy see' then + parseSee(callback, vm, source) + end + + if #list == 0 then + parseValueSimily(callback, vm, source) + end + + return list, isGlobal +end diff --git a/script/core/diagnostics.lua b/script/core/diagnostics.lua new file mode 100644 index 00000000..3b11b818 --- /dev/null +++ b/script/core/diagnostics.lua @@ -0,0 +1,1042 @@ +local lang = require 'language' +local config = require 'config' +local library = require 'core.library' +local buildGlobal = require 'vm.global' +local DiagnosticSeverity = require 'constant.DiagnosticSeverity' +local DiagnosticDefaultSeverity = require 'constant.DiagnosticDefaultSeverity' +local DiagnosticTag = require 'constant.DiagnosticTag' + +local mt = {} +mt.__index = mt + +local function isContainPos(obj, start, finish) + if obj.start <= start and obj.finish >= finish then + return true + end + return false +end + +function mt:searchUnusedLocals(callback) + self.vm:eachSource(function (source) + local loc = source:bindLocal() + if not loc then + return + end + if loc:get 'emmy arg' then + return + end + local name = loc:getName() + if name == '_' or name == '_ENV' or name == '' then + return + end + if source:action() ~= 'local' then + return + end + if loc:get 'hide' then + return + end + local used = loc:eachInfo(function (info) + if info.type == 'get' then + return true + end + end) + if not used then + callback(source.start, source.finish, name) + end + end) +end + +function mt:searchUnusedFunctions(callback) + self.vm:eachSource(function (source) + local loc = source:bindLocal() + if not loc then + return + end + if loc:get 'emmy arg' then + return + end + if source:action() ~= 'local' then + return + end + if loc:get 'hide' then + return + end + local used = loc:eachInfo(function (info) + if info.type == 'get' then + return true + end + end) + if used then + return + end + loc:eachInfo(function (info, src) + if info.type == 'set' or info.type == 'local' then + local v = src:bindValue() + local func = v and v:getFunction() + if func and func:getSource().uri == self.vm.uri then + callback(func:getSource().start, func:getSource().finish) + end + end + end) + end) +end + +function mt:searchUndefinedGlobal(callback) + local definedGlobal = {} + for name in pairs(config.config.diagnostics.globals) do + definedGlobal[name] = true + end + local envValue = buildGlobal(self.vm.lsp) + envValue:eachInfo(function (info) + if info.type == 'set child' then + local name = info[1] + definedGlobal[name] = true + end + end) + self.vm:eachSource(function (source) + if not source:get 'global' then + return + end + local name = source:getName() + if name == '' then + return + end + local parent = source:get 'parent' + if not parent then + return + end + if not parent:get 'ENV' and not source:get 'in index' then + return + end + if definedGlobal[name] then + return + end + if type(name) ~= 'string' then + return + end + callback(source.start, source.finish, name) + end) +end + +function mt:searchUnusedLabel(callback) + self.vm:eachSource(function (source) + local label = source:bindLabel() + if not label then + return + end + if source:action() ~= 'set' then + return + end + local used = label:eachInfo(function (info) + if info.type == 'get' then + return true + end + end) + if not used then + callback(source.start, source.finish, label:getName()) + end + end) +end + +function mt:searchUnusedVararg(callback) + self.vm:eachSource(function (source) + local value = source:bindFunction() + if not value then + return + end + local func = value:getFunction() + if not func then + return + end + if func._dotsSource and not func._dotsLoad then + callback(func._dotsSource.start, func._dotsSource.finish) + end + end) +end + +local function isInString(vm, start, finish) + return vm:eachSource(function (source) + if source.type == 'string' and isContainPos(source, start, finish) then + return true + end + end) +end + +function mt:searchSpaces(callback) + local vm = self.vm + local lines = self.lines + for i = 1, #lines do + local line = lines:line(i) + + if line:find '^[ \t]+$' then + local start, finish = lines:range(i) + if isInString(vm, start, finish) then + goto NEXT_LINE + end + callback(start, finish, lang.script.DIAG_LINE_ONLY_SPACE) + goto NEXT_LINE + end + + local pos = line:find '[ \t]+$' + if pos then + local start, finish = lines:range(i) + start = start + pos - 1 + if isInString(vm, start, finish) then + goto NEXT_LINE + end + callback(start, finish, lang.script.DIAG_LINE_POST_SPACE) + goto NEXT_LINE + end + + ::NEXT_LINE:: + end +end + +function mt:searchRedefinition(callback) + local used = {} + local uri = self.uri + self.vm:eachSource(function (source) + local loc = source:bindLocal() + if not loc then + return + end + local shadow = loc:shadow() + if not shadow then + return + end + if used[shadow] then + return + end + used[shadow] = true + if loc:get 'hide' then + return + end + local name = loc:getName() + if name == '_' or name == '_ENV' or name == '' then + return + end + local related = {} + for i = 1, #shadow do + related[i] = { + start = shadow[i]:getSource().start, + finish = shadow[i]:getSource().finish, + uri = uri, + } + end + for i = 2, #shadow do + callback(shadow[i]:getSource().start, shadow[i]:getSource().finish, name, related) + end + end) +end + +function mt:searchNewLineCall(callback) + local lines = self.lines + self.vm:eachSource(function (source) + if source.type ~= 'simple' then + return + end + for i = 1, #source - 1 do + local callSource = source[i] + local funcSource = source[i-1] + if callSource.type ~= 'call' then + goto CONTINUE + end + local callLine = lines:rowcol(callSource.start) + local funcLine = lines:rowcol(funcSource.finish) + if callLine > funcLine then + callback(callSource.start, callSource.finish) + end + :: CONTINUE :: + end + end) +end + +function mt:searchNewFieldCall(callback) + local lines = self.lines + self.vm:eachSource(function (source) + if source.type ~= 'table' then + return + end + for i = 1, #source do + local field = source[i] + if field.type == 'simple' then + local callSource = field[#field] + local funcSource = field[#field-1] + local callLine = lines:rowcol(callSource.start) + local funcLine = lines:rowcol(funcSource.finish) + if callLine > funcLine then + callback(funcSource.start, callSource.finish + , lines.buf:sub(funcSource.start, funcSource.finish) + , lines.buf:sub(callSource.start, callSource.finish) + ) + end + end + end + end) +end + +function mt:searchRedundantParameters(callback) + self.vm:eachSource(function (source) + local args = source:bindCall() + if not args then + return + end + + -- 回调函数不检查 + local simple = source:get 'simple' + if simple and simple[2] == source then + local loc = simple[1]:bindLocal() + if loc then + local source = loc:getSource() + if source:get 'arg' then + return + end + end + end + + local value = source:findCallFunction() + if not value then + return + end + + local func = value:getFunction() + -- 参数中有 ... ,不用再检查了 + if func:hasDots() then + return + end + local max = #func.args + local passed = #args + -- function m.open() end + -- m:open() + -- 这种写法不算错 + if passed == 1 and source:get 'has object' then + return + end + for i = max + 1, passed do + local extra = args[i] + callback(extra.start, extra.finish, max, passed) + end + end) +end + +local opMap = { + ['+'] = true, + ['-'] = true, + ['*'] = true, + ['/'] = true, + ['//'] = true, + ['^'] = true, + ['<<'] = true, + ['>>'] = true, + ['&'] = true, + ['|'] = true, + ['~'] = true, + ['..'] = true, +} + +local literalMap = { + ['number'] = true, + ['boolean'] = true, + ['string'] = true, + ['table'] = true, +} + +function mt:searchAmbiguity1(callback) + self.vm:eachSource(function (source) + if source.op ~= 'or' then + return + end + local first = source[1] + local second = source[2] + -- a + (b or 0) --> (a + b) or 0 + do + if opMap[first.op] + and first.type ~= 'unary' + and not second.op + and literalMap[second.type] + and not first.brackets + then + callback(source.start, source.finish, first.start, first.finish) + end + end + -- (a or 0) + c --> a or (0 + c) + do + if opMap[second.op] + and second.type ~= 'unary' + and not first.op + and literalMap[second[1].type] + and not second.brackets + then + callback(source.start, source.finish, second.start, second.finish) + end + end + end) +end + +function mt:searchLowercaseGlobal(callback) + local definedGlobal = {} + for name in pairs(config.config.diagnostics.globals) do + definedGlobal[name] = true + end + for name in pairs(library.global) do + definedGlobal[name] = true + end + self.vm:eachSource(function (source) + if source.type == 'name' + and source:get 'parent' + and not source:get 'simple' + and not source:get 'table index' + and source:action() == 'set' + then + local name = source[1] + if definedGlobal[name] then + return + end + local first = name:match '%w' + if not first then + return + end + if first:match '%l' then + callback(source.start, source.finish) + end + end + end) +end + +function mt:searchDuplicateIndex(callback) + self.vm:eachSource(function (source) + if source.type ~= 'table' then + return + end + local mark = {} + for _, obj in ipairs(source) do + if obj.type == 'pair' then + local key = obj[1] + local name + if key.index then + if key.type == 'string' then + name = key[1] + end + elseif key.type == 'name' then + name = key[1] + end + if name then + if mark[name] then + mark[name][#mark[name]+1] = obj + else + mark[name] = { obj } + end + end + end + end + for name, defs in pairs(mark) do + if #defs > 1 then + local related = {} + for i = 1, #defs do + related[i] = { + start = defs[i][1].start, + finish = defs[i][2].finish, + uri = self.uri, + } + end + for i = 1, #defs - 1 do + callback(defs[i][1].start, defs[i][2].finish, name, related, 'unused') + end + for i = #defs, #defs do + callback(defs[i][1].start, defs[i][1].finish, name, related, 'duplicate') + end + end + end + end) +end + +function mt:searchDuplicateMethod(callback) + local uri = self.uri + local mark = {} + local map = {} + self.vm:eachSource(function (source) + local parent = source:get 'parent' + if not parent then + return + end + if mark[parent] then + return + end + mark[parent] = true + local relates = {} + parent:eachInfo(function (info, src) + local k = info[1] + if info.type ~= 'set child' then + return + end + if type(k) ~= 'string' then + return + end + if src.start == 0 then + return + end + if not src:get 'object' then + return + end + if map[src] then + return + end + if not relates[k] then + relates[k] = map[src] or { + name = k, + } + end + map[src] = relates[k] + relates[k][#relates[k]+1] = { + start = src.start, + finish = src.finish, + uri = src.uri + } + end) + end) + for src, relate in pairs(map) do + if #relate > 1 and src.uri == uri then + callback(src.start, src.finish, relate.name, relate) + end + end +end + +function mt:searchEmptyBlock(callback) + self.vm:eachSource(function (source) + -- 认为空repeat与空while是合法的 + -- 要去vm中激活source + if source.type == 'if' then + for _, block in ipairs(source) do + if #block > 0 then + return + end + end + callback(source.start, source.finish) + return + end + if source.type == 'loop' + or source.type == 'in' + then + if #source == 0 then + callback(source.start, source.finish) + end + return + end + end) +end + +function mt:searchRedundantValue(callback) + self.vm:eachSource(function (source) + if source.type == 'set' or source.type == 'local' then + local args = source[1] + local values = source[2] + if not source[2] then + return + end + local argCount, valueCount + if args.type == 'list' then + argCount = #args + else + argCount = 1 + end + if values.type == 'list' then + valueCount = #values + else + valueCount = 1 + end + for i = argCount + 1, valueCount do + local value = values[i] + callback(value.start, value.finish, argCount, valueCount) + end + end + end) +end + +function mt:searchUndefinedEnvChild(callback) + self.vm:eachSource(function (source) + if not source:get 'global' then + return + end + local name = source:getName() + if name == '' then + return + end + if source:get 'in index' then + return + end + local parent = source:get 'parent' + if parent:get 'ENV' then + return + end + local value = source:bindValue() + if not value then + return + end + if value:getSource() == source then + callback(source.start, source.finish, name) + end + return + end) +end + +function mt:searchGlobalInNilEnv(callback) + self.vm:eachSource(function (source) + if not source:get 'global' then + return + end + local name = source:getName() + if name == '' then + return + end + local parentSource = source:get 'parent' :getSource() + if parentSource and parentSource.type == 'nil' then + callback(source.start, source.finish, { + { + start = parentSource.start, + finish = parentSource.finish, + uri = self.uri, + } + }) + end + return + end) +end + +function mt:checkEmmyClass(source, callback) + local class = source:get 'emmy.class' + if not class then + return + end + -- class重复定义 + local name = class:getName() + local related = {} + self.vm.emmyMgr:eachClass(name, function (class) + if class.type ~= 'emmy.class' and class.type ~= 'emmy.alias' then + return + end + local src = class:getSource() + if src ~= source then + related[#related+1] = { + start = src.start, + finish = src.finish, + uri = src.uri, + } + end + end) + if #related > 0 then + callback(source.start, source.finish, lang.script.DIAG_DUPLICATE_CLASS ,related) + end + -- 继承不存在的class + local extends = class.extends + if not extends then + return + end + local parent = self.vm.emmyMgr:eachClass(extends, function (parent) + if parent.type == 'emmy.class' then + return parent + end + end) + if not parent then + callback(source[2].start, source[2].finish, lang.script.DIAG_UNDEFINED_CLASS) + return + end + + -- class循环继承 + local related = {} + local current = class + for _ = 1, 10 do + local extends = current.extends + if not extends then + break + end + related[#related+1] = { + start = current:getSource().start, + finish = current:getSource().finish, + uri = current:getSource().uri, + } + current = self.vm.emmyMgr:eachClass(extends, function (parent) + if parent.type == 'emmy.class' then + return parent + end + end) + if not current then + break + end + if current:getName() == class:getName() then + callback(source.start, source.finish, lang.script.DIAG_CYCLIC_EXTENDS, related) + break + end + end +end + +function mt:checkEmmyType(source, callback) + for _, tpsource in ipairs(source) do + local name = tpsource[1] + local class = self.vm.emmyMgr:eachClass(name, function (class) + if class.type == 'emmy.class' or class.type == 'emmy.alias' then + return class + end + end) + if not class then + callback(tpsource.start, tpsource.finish, lang.script.DIAG_UNDEFINED_CLASS) + end + end +end + +function mt:checkEmmyAlias(source, callback) + local class = source:get 'emmy.alias' + if not class then + return + end + -- class重复定义 + local name = class:getName() + local related = {} + self.vm.emmyMgr:eachClass(name, function (class) + if class.type ~= 'emmy.class' and class.type ~= 'emmy.alias' then + return + end + local src = class:getSource() + if src ~= source then + related[#related+1] = { + start = src.start, + finish = src.finish, + uri = src.uri, + } + end + end) + if #related > 0 then + callback(source.start, source.finish, lang.script.DIAG_DUPLICATE_CLASS ,related) + end +end + +function mt:checkEmmyParam(source, callback, mark) + local func = source:get 'emmy function' + if not func then + return + end + if mark[func] then + return + end + mark[func] = true + + -- 检查不存在的参数 + local emmyParams = func:getEmmyParams() + local funcParams = {} + if func.args then + for _, arg in ipairs(func.args) do + funcParams[arg.name] = true + end + end + for _, param in ipairs(emmyParams) do + local name = param:getName() + if not funcParams[name] then + callback(param:getSource()[1].start, param:getSource()[1].finish, lang.script.DIAG_INEXISTENT_PARAM) + end + end + + -- 检查重复的param + local lists = {} + for _, param in ipairs(emmyParams) do + local name = param:getName() + if not lists[name] then + lists[name] = {} + end + lists[name][#lists[name]+1] = param:getSource()[1] + end + for _, list in pairs(lists) do + if #list > 1 then + local related = {} + for _, src in ipairs(list) do + related[#related+1] = { + src.start, + src.finish, + src.uri, + } + callback(src.start, src.finish, lang.script.DIAG_DUPLICATE_PARAM) + end + end + end +end + +function mt:checkEmmyField(source, callback, mark) + ---@type EmmyClass + local class = source:get 'target class' + -- 必须写在 class 的后面 + if not class then + callback(source.start, source.finish, lang.script.DIAG_NEED_CLASS) + end + + -- 检查重复的 field + if class and not mark[class] then + mark[class] = true + local lists = {} + class:eachField(function (field) + local name = field:getName() + if not lists[name] then + lists[name] = {} + end + lists[name][#lists[name]+1] = field:getSource()[2] + end) + for _, list in pairs(lists) do + if #list > 1 then + local related = {} + for _, src in ipairs(list) do + related[#related+1] = { + src.start, + src.finish, + src.uri, + } + callback(src.start, src.finish, lang.script.DIAG_DUPLICATE_FIELD) + end + end + end + end +end + +function mt:searchEmmyLua(callback) + local mark = {} + self.vm:eachSource(function (source) + if source.type == 'emmyClass' then + self:checkEmmyClass(source, callback) + elseif source.type == 'emmyType' then + self:checkEmmyType(source, callback) + elseif source.type == 'emmyAlias' then + self:checkEmmyAlias(source, callback) + elseif source.type == 'emmyParam' then + self:checkEmmyParam(source, callback, mark) + elseif source.type == 'emmyField' then + self:checkEmmyField(source, callback, mark) + end + end) +end + +function mt:searchSetConstLocal(callback) + local mark = {} + self.vm:eachSource(function (source) + local loc = source:bindLocal() + if not loc then + return + end + if mark[loc] then + return + end + mark[loc] = true + if not loc.tags then + return + end + local const + for _, tag in ipairs(loc.tags) do + if tag[1] == 'const' then + const = true + break + end + end + if not const then + return + end + loc:eachInfo(function (info, src) + if info.type == 'set' then + callback(src.start, src.finish) + end + end) + end) +end + +function mt:doDiagnostics(func, code, callback) + if config.config.diagnostics.disable[code] then + return + end + local level = config.config.diagnostics.severity[code] + if not DiagnosticSeverity[level] then + level = DiagnosticDefaultSeverity[code] + end + func(self, function (start, finish, ...) + local data = callback(...) + data.code = code + data.start = start + data.finish = finish + data.level = data.level or DiagnosticSeverity[level] + self.datas[#self.datas+1] = data + end) + if coroutine.isyieldable() then + if self.vm:isRemoved() then + coroutine.yield('stop') + else + coroutine.yield() + end + end +end + +return function (vm, lines, uri) + local session = setmetatable({ + vm = vm, + lines = lines, + uri = uri, + datas = {}, + }, mt) + + -- 未使用的局部变量 + session:doDiagnostics(session.searchUnusedLocals, 'unused-local', function (key) + return { + message = lang.script('DIAG_UNUSED_LOCAL', key), + tags = {DiagnosticTag.Unnecessary}, + } + end) + -- 未使用的函数 + session:doDiagnostics(session.searchUnusedFunctions, 'unused-function', function () + return { + message = lang.script.DIAG_UNUSED_FUNCTION, + tags = {DiagnosticTag.Unnecessary}, + } + end) + -- 读取未定义全局变量 + session:doDiagnostics(session.searchUndefinedGlobal, 'undefined-global', function (key) + local message = lang.script('DIAG_UNDEF_GLOBAL', key) + local otherVersion = library.other[key] + local customLib = library.custom[key] + if otherVersion then + message = ('%s(%s)'):format(message, lang.script('DIAG_DEFINED_VERSION', table.concat(otherVersion, '/'), config.config.runtime.version)) + end + if customLib then + message = ('%s(%s)'):format(message, lang.script('DIAG_DEFINED_CUSTOM', table.concat(customLib, '/'))) + end + return { + message = message, + } + end) + -- 未使用的Label + session:doDiagnostics(session.searchUnusedLabel, 'unused-label', function (key) + return { + message = lang.script('DIAG_UNUSED_LABEL', key), + tags = {DiagnosticTag.Unnecessary}, + } + end) + -- 未使用的不定参数 + session:doDiagnostics(session.searchUnusedVararg, 'unused-vararg', function () + return { + message = lang.script.DIAG_UNUSED_VARARG, + tags = {DiagnosticTag.Unnecessary}, + } + end) + -- 只有空格与制表符的行,以及后置空格 + session:doDiagnostics(session.searchSpaces, 'trailing-space', function (message) + return { + message = message, + } + end) + -- 重定义局部变量 + session:doDiagnostics(session.searchRedefinition, 'redefined-local', function (key, related) + return { + message = lang.script('DIAG_REDEFINED_LOCAL', key), + related = related, + } + end) + -- 以括号开始的一行(可能被误解析为了上一行的call) + session:doDiagnostics(session.searchNewLineCall, 'newline-call', function () + return { + message = lang.script.DIAG_PREVIOUS_CALL, + } + end) + -- 以字符串开始的field(可能被误解析为了上一行的call) + session:doDiagnostics(session.searchNewFieldCall, 'newfield-call', function (func, call) + return { + message = lang.script('DIAG_PREFIELD_CALL', func, call), + } + end) + -- 调用函数时的参数数量是否超过函数的接收数量 + session:doDiagnostics(session.searchRedundantParameters, 'redundant-parameter', function (max, passed) + return { + message = lang.script('DIAG_OVER_MAX_ARGS', max, passed), + tags = {DiagnosticTag.Unnecessary}, + } + end) + -- x or 0 + 1 + session:doDiagnostics(session.searchAmbiguity1, 'ambiguity-1', function (start, finish) + return { + message = lang.script('DIAG_AMBIGUITY_1', lines.buf:sub(start, finish)), + } + end) + -- 不允许定义首字母小写的全局变量(很可能是拼错或者漏删) + session:doDiagnostics(session.searchLowercaseGlobal, 'lowercase-global', function () + return { + message = lang.script.DIAG_LOWERCASE_GLOBAL, + } + end) + -- 未定义的变量(重载了 `_ENV`) + session:doDiagnostics(session.searchUndefinedEnvChild, 'undefined-env-child', function (key) + if vm.envType == '_ENV' then + return { + message = lang.script('DIAG_UNDEF_ENV_CHILD', key), + } + else + return { + message = lang.script('DIAG_UNDEF_FENV_CHILD', key), + } + end + end) + -- 全局变量不可用(置空了 `_ENV`) + session:doDiagnostics(session.searchGlobalInNilEnv, 'global-in-nil-env', function (related) + if vm.envType == '_ENV' then + return { + message = lang.script.DIAG_GLOBAL_IN_NIL_ENV, + related = related, + } + else + return { + message = lang.script.DIAG_GLOBAL_IN_NIL_FENV, + related = related, + } + end + end) + -- 构建表时重复定义field + session:doDiagnostics(session.searchDuplicateIndex, 'duplicate-index', function (key, related, type) + if type == 'unused' then + return { + message = lang.script('DIAG_DUPLICATE_INDEX', key), + related = related, + level = DiagnosticSeverity.Hint, + tags = {DiagnosticTag.Unnecessary}, + } + else + return { + message = lang.script('DIAG_DUPLICATE_INDEX', key), + related = related, + } + end + end) + -- 往表里面塞重复的method + --session:doDiagnostics(session.searchDuplicateMethod, 'duplicate-method', function (key, related) + -- return { + -- message = lang.script('DIAG_DUPLICATE_METHOD', key), + -- related = related, + -- } + --end) + -- 空代码块 + session:doDiagnostics(session.searchEmptyBlock, 'empty-block', function () + return { + message = lang.script.DIAG_EMPTY_BLOCK, + tags = {DiagnosticTag.Unnecessary}, + } + end) + -- 多余的赋值 + session:doDiagnostics(session.searchRedundantValue, 'redundant-value', function (max, passed) + return { + message = lang.script('DIAG_OVER_MAX_VALUES', max, passed), + tags = {DiagnosticTag.Unnecessary}, + } + end) + -- Emmy相关的检查 + session:doDiagnostics(session.searchEmmyLua, 'emmy-lua', function (message, related) + return { + message = message, + related = related, + } + end) + -- 检查给const变量赋值 + session:doDiagnostics(session.searchSetConstLocal, 'set-const', function () + return { + message = lang.script.DIAG_SET_CONST + } + end) + return session.datas +end diff --git a/script/core/document_symbol.lua b/script/core/document_symbol.lua new file mode 100644 index 00000000..48e01332 --- /dev/null +++ b/script/core/document_symbol.lua @@ -0,0 +1,260 @@ +local hoverFunction = require 'core.hover.function' +local getName = require 'core.name' +local hover = require 'core.hover' + +local SymbolKind = { + File = 1, + Module = 2, + Namespace = 3, + Package = 4, + Class = 5, + Method = 6, + Property = 7, + Field = 8, + Constructor = 9, + Enum = 10, + Interface = 11, + Function = 12, + Variable = 13, + Constant = 14, + String = 15, + Number = 16, + Boolean = 17, + Array = 18, + Object = 19, + Key = 20, + Null = 21, + EnumMember = 22, + Struct = 23, + Event = 24, + Operator = 25, + TypeParameter = 26, +} + +local function buildLocal(vm, source, used, callback) + local vars = source[1] + local exps = source[2] + if vars.type ~= 'list' then + vars = {vars} + end + if not exps or exps.type ~= 'list' then + exps = {exps} + end + for i, var in ipairs(vars) do + local exp = exps[i] + local data = {} + local loc = var:bindLocal() + data.name = loc:getName() + data.range = { var.start, var.finish } + data.selectionRange = { var.start, var.finish } + if exp then + local hvr = hover(var) + if exp.type == 'function' then + data.kind = SymbolKind.Function + else + data.kind = SymbolKind.Variable + end + data.detail = hvr.label:gsub('[\r\n]', '') + data.valueRange = { exp.start, exp.finish } + used[exp] = true + else + data.kind = SymbolKind.Variable + data.detail = '' + data.valueRange = { var.start, var.finish } + end + callback(data) + end +end + +local function buildSet(vm, source, used, callback) + local vars = source[1] + local exps = source[2] + if vars.type ~= 'list' then + vars = {vars} + end + if not exps or exps.type ~= 'list' then + exps = {exps} + end + for i, var in ipairs(vars) do + if var:bindLocal() then + goto CONTINUE + end + local exp = exps[i] + local data = {} + data.name = getName(var) + data.range = { var.start, var.finish } + data.selectionRange = { var.start, var.finish } + if exp then + local hvr = hover(var) + if not hvr then + goto CONTINUE + end + if exp.type == 'function' then + data.kind = SymbolKind.Function + else + data.kind = SymbolKind.Property + end + data.detail = hvr.label:gsub('[\r\n]', '') + data.valueRange = { exp.start, exp.finish } + used[exp] = true + else + data.kind = SymbolKind.Property + data.detail = '' + data.valueRange = { var.start, var.finish } + end + callback(data) + :: CONTINUE :: + end +end + +local function buildPair(vm, source, used, callback) + local var = source[1] + local exp = source[2] + local data = {} + data.name = getName(var) + data.range = { var.start, var.finish } + data.selectionRange = { var.start, var.finish } + if exp then + local hvr = hover(var) + if not hvr then + return + end + if exp.type == 'function' then + data.kind = SymbolKind.Function + else + data.kind = SymbolKind.Class + end + data.detail = hvr.label:gsub('[\r\n]', '') + data.valueRange = { exp.start, exp.finish } + used[exp] = true + else + data.kind = SymbolKind.Class + data.detail = '' + data.valueRange = { var.start, var.finish } + end + callback(data) +end + +local function buildLocalFunction(vm, source, used, callback) + local value = source:bindFunction() + if not value then + return + end + local name = getName(source.name) + local hvr = hoverFunction(name, value:getFunction()) + if not hvr then + return + end + local kind = SymbolKind.Function + callback { + name = name, + detail = hvr.label:gsub('[\r\n]', ''), + kind = kind, + range = { source.start, source.finish }, + selectionRange = { source.name.start, source.name.finish }, + valueRange = { source.start, source.finish }, + } +end + + +local function buildFunction(vm, source, used, callback) + if used[source] then + return + end + local value = source:bindFunction() + if not value then + return + end + local name = getName(source.name) + local func = value:getFunction() + if not func then + return + end + local hvr = hoverFunction(name, func, func:getObject()) + if not hvr then + return + end + local data = {} + data.name = name + data.detail = hvr.label:gsub('[\r\n]', '') + data.range = { source.start, source.finish } + data.valueRange = { source.start, source.finish } + if source.name then + data.selectionRange = { source.name.start, source.name.finish } + else + data.selectionRange = { source.start, source.start } + end + if func:getObject() then + data.kind = SymbolKind.Field + else + data.kind = SymbolKind.Function + end + callback(data) +end + +local function buildSource(vm, source, used, callback) + if source.type == 'local' then + buildLocal(vm, source, used, callback) + return + end + if source.type == 'set' then + buildSet(vm, source, used, callback) + return + end + if source.type == 'pair' then + buildPair(vm, source, used, callback) + return + end + if source.type == 'localfunction' then + buildLocalFunction(vm, source, used, callback) + return + end + if source.type == 'function' then + buildFunction(vm, source, used, callback) + return + end +end + +local function packChild(symbols, finish, kind) + local t + while true do + local symbol = symbols[#symbols] + if not symbol then + break + end + if symbol.valueRange[1] > finish then + break + end + symbols[#symbols] = nil + symbol.children = packChild(symbols, symbol.valueRange[2], symbol.kind) + if not t then + t = {} + end + t[#t+1] = symbol + end + return t +end + +local function packSymbols(symbols) + -- 按照start位置反向排序 + table.sort(symbols, function (a, b) + return a.range[1] > b.range[1] + end) + -- 处理嵌套 + return packChild(symbols, math.maxinteger, SymbolKind.Function) +end + +return function (vm) + local symbols = {} + local used = {} + + vm:eachSource(function (source) + buildSource(vm, source, used, function (data) + symbols[#symbols+1] = data + end) + end) + + local packedSymbols = packSymbols(symbols) + + return packedSymbols +end diff --git a/script/core/find_lib.lua b/script/core/find_lib.lua new file mode 100644 index 00000000..e76549a8 --- /dev/null +++ b/script/core/find_lib.lua @@ -0,0 +1,65 @@ +local hoverName = require 'core.hover.name' + +local function getParentName(lib, isObject) + for _, parent in ipairs(lib.parent) do + if isObject then + if parent.type == 'object' then + return parent.nick or parent.name + end + else + if parent.type ~= 'object' then + return parent.nick or parent.name + end + end + end + return '' +end + +local function findLib(source) + local value = source:bindValue() + local lib = value:getLib() + if not lib then + return nil + end + if lib.parent then + if source:get 'object' then + -- *string:sub + local fullKey = ('*%s:%s'):format(getParentName(lib, true), lib.name) + return lib, fullKey + else + local parentValue = source:get 'parent' + if parentValue and parentValue:getType() == 'string' then + -- *string.sub + local fullKey = ('*%s.%s'):format(getParentName(lib, false), lib.name) + return lib, fullKey + else + -- string.sub + local fullKey = ('%s.%s'):format(getParentName(lib, false), lib.name) + return lib, fullKey + end + end + else + local name = hoverName(source) + local libName = lib.nick or lib.name + if name == libName or not libName then + return lib, name + elseif name == '' then + return lib, libName + else + return lib, ('%s<%s>'):format(name, libName) + end + end +end + +return function (source) + if source:bindValue() then + local lib, fullKey = findLib(source) + return lib, fullKey + end + if source:get 'in index' then + source = source:get 'in index' + local lib, fullKey = findLib(source) + return lib, fullKey + end + return nil +end diff --git a/script/core/find_source.lua b/script/core/find_source.lua new file mode 100644 index 00000000..a64a047e --- /dev/null +++ b/script/core/find_source.lua @@ -0,0 +1,59 @@ +local function isContainPos(obj, pos) + if obj.start <= pos and obj.finish >= pos then + return true + end + return false +end + +local function isValidSource(source) + return source.start ~= nil and source.start ~= 0 +end + +local function matchFilter(source, filter) + if not filter then + return true + end + return filter[source.type] +end + +local function findAtPos(vm, pos, filter) + local res = {} + vm:eachSource(function (source) + if isValidSource(source) + and isContainPos(source, pos) + and matchFilter(source, filter) + then + res[#res+1] = source + end + end) + if #res == 0 then + return nil + end + table.sort(res, function (a, b) + if a == b then + return false + end + local rangeA = a.finish - a.start + local rangeB = b.finish - b.start + -- 特殊处理:func 'str' 的情况下,list与string的范围会完全相同,此时取string + if rangeA == rangeB then + if b.type == 'call' and #b == 1 and b[1] == a then + return true + elseif a.type == 'call' and #a == 1 and a[1] == b then + return false + else + return a.id < b.id + end + end + return rangeA < rangeB + end) + local source = res[1] + if not source then + return nil + end + return source +end + +return function (vm, pos, filter) + return findAtPos(vm, pos, filter) +end diff --git a/script/core/folding_range.lua b/script/core/folding_range.lua new file mode 100644 index 00000000..e94d1ffe --- /dev/null +++ b/script/core/folding_range.lua @@ -0,0 +1,73 @@ +local foldingType = { + ['function'] = {'region', 'end', }, + ['localfunction'] = {'region', 'end', }, + ['do'] = {'region', 'end', }, + ['if'] = {'region', 'end', }, + ['loop'] = {'region', 'end', }, + ['in'] = {'region', 'end', }, + ['while'] = {'region', 'end', }, + ['repeat'] = {'region', 'until',}, + ['table'] = {'region', '}', }, + ['string'] = {'regtion', ']', }, +} + +return function (vm, comments) + local result = {} + vm:eachSource(function (source) + local tp = source.type + local data = foldingType[tp] + if not data then + return + end + local start = source.start + local finish = source.finish + if tp == 'repeat' then + if #source > 0 then + finish = source[#source].finish + else + finish = start + #'repeat' + end + finish = vm.text:find('until', finish, true) or finish + result[#result+1] = { + start = start, + finish = finish, + kind = data[1], + } + elseif tp == 'if' then + for i = 1, #source do + local block = source[i] + local nblock = source[i+1] + result[#result+1] = { + start = block.start, + finish = nblock and nblock.start or finish, + kind = data[1], + } + end + elseif tp == 'string' then + result[#result+1] = { + start = start, + finish = finish, + kind = data[1], + } + elseif data[1] == 'region' then + result[#result+1] = { + start = start, + finish = finish, + kind = data[1], + } + end + end) + if comments then + for _, comment in ipairs(comments) do + result[#result+1] = { + start = comment.start, + finish = comment.finish, + kind = 'comment', + } + end + end + if #result == 0 then + return nil + end + return result +end diff --git a/script/core/global.lua b/script/core/global.lua new file mode 100644 index 00000000..961ad304 --- /dev/null +++ b/script/core/global.lua @@ -0,0 +1,49 @@ +local mt = {} +mt.__index = mt + +function mt:markSet(uri) + if not uri then + return + end + self.set[uri] = true +end + +function mt:markGet(uri) + if not uri then + return + end + self.get[uri] = true +end + +function mt:clearGlobal(uri) + self.set[uri] = nil + self.get[uri] = nil +end + +function mt:getAllUris() + local uris = {} + for uri in pairs(self.set) do + uris[#uris+1] = uri + end + for uri in pairs(self.get) do + if not self.set[uri] then + uris[#uris+1] = uri + end + end + return uris +end + +function mt:hasSetGlobal(uri) + return self.set[uri] ~= nil +end + +function mt:remove() +end + +return function (lsp) + return setmetatable({ + get = {}, + set = {}, + lsp = lsp, + }, mt) +end diff --git a/script/core/highlight.lua b/script/core/highlight.lua new file mode 100644 index 00000000..2073573d --- /dev/null +++ b/script/core/highlight.lua @@ -0,0 +1,54 @@ +local findSource = require 'core.find_source' +local parser = require 'parser' + +local DocumentHighlightKind = { + Text = 1, + Read = 2, + Write = 3, +} + +local function parseResult(source) + local positions = {} + if source:bindLabel() then + source:bindLabel():eachInfo(function (info, src) + positions[#positions+1] = { src.start, src.finish, DocumentHighlightKind.Text } + end) + return positions + end + if source:bindLocal() then + local loc = source:bindLocal() + local mark = {} + loc:eachInfo(function (info, src) + if not mark[src] then + mark[src] = info + positions[#positions+1] = { src.start, src.finish, DocumentHighlightKind.Text } + end + end) + return positions + end + if source:bindValue() and source:get 'parent' then + local parent = source:get 'parent' + local mark = {} + parent:eachInfo(function (info, src) + if not mark[src] and source.uri == src.uri then + mark[src] = info + if info.type == 'get child' or info.type == 'set child' then + if info[1] == source[1] then + positions[#positions+1] = {src.start, src.finish, DocumentHighlightKind.Text} + end + end + end + end) + return positions + end + return nil +end + +return function (vm, pos) + local source = findSource(vm, pos) + if not source then + return nil + end + local positions = parseResult(source) + return positions +end diff --git a/script/core/hover/emmy_function.lua b/script/core/hover/emmy_function.lua new file mode 100644 index 00000000..7c87954e --- /dev/null +++ b/script/core/hover/emmy_function.lua @@ -0,0 +1,143 @@ +---@param emmy EmmyFunctionType +local function buildEmmyArgs(emmy, object, select) + local start + if object then + start = 2 + else + start = 1 + end + local strs = {} + local args = {} + local i = 0 + emmy:eachParam(function (name, typeObj) + i = i + 1 + if i < start then + return + end + if i > start then + strs[#strs+1] = ', ' + end + if i == select then + strs[#strs+1] = '@ARG' + end + strs[#strs+1] = name .. ': ' .. typeObj:getType() + args[#args+1] = strs[#strs] + if i == select then + strs[#strs+1] = '@ARG' + end + end) + local text = table.concat(strs) + local argLabel = {} + for i = 1, 2 do + local pos = text:find('@ARG', 1, true) + if pos then + if i == 1 then + argLabel[i] = pos + else + argLabel[i] = pos - 1 + end + text = text:sub(1, pos-1) .. text:sub(pos+4) + end + end + if #argLabel == 0 then + argLabel = nil + end + return text, argLabel, args +end + +local function buildEmmyReturns(emmy) + local rtns = {} + local i = 0 + emmy:eachReturn(function (rtn) + i = i + 1 + if i > 1 then + rtns[#rtns+1] = ('\n% 3d. '):format(i) + end + rtns[#rtns+1] = rtn:getType() + end) + if #rtns == 0 then + return '\n -> ' .. 'any' + else + return '\n -> ' .. table.concat(rtns) + end +end + +local function buildEnum(lib) + if not lib.enums then + return '' + end + local container = table.container() + for _, enum in ipairs(lib.enums) do + if not enum.name or (not enum.enum and not enum.code) then + goto NEXT_ENUM + end + if not container[enum.name] then + container[enum.name] = {} + if lib.args then + for _, arg in ipairs(lib.args) do + if arg.name == enum.name then + container[enum.name].type = arg.type + break + end + end + end + if lib.returns then + for _, rtn in ipairs(lib.returns) do + if rtn.name == enum.name then + container[enum.name].type = rtn.type + break + end + end + end + end + table.insert(container[enum.name], enum) + ::NEXT_ENUM:: + end + local strs = {} + local raw = {} + for name, enums in pairs(container) do + local tp + if type(enums.type) == 'table' then + tp = table.concat(enums.type, '/') + else + tp = enums.type + end + raw[name] = {} + strs[#strs+1] = ('\n%s: %s'):format(name, tp or 'any') + for _, enum in ipairs(enums) do + if enum.default then + strs[#strs+1] = '\n -> ' + else + strs[#strs+1] = '\n | ' + end + if enum.code then + strs[#strs+1] = tostring(enum.code) + else + strs[#strs+1] = ('%q'):format(enum.enum) + end + raw[name][#raw[name]+1] = strs[#strs] + if enum.description then + strs[#strs+1] = ' -- ' .. enum.description + end + end + end + return table.concat(strs), raw +end + +return function (name, emmy, object, select) + local argStr, argLabel, args = buildEmmyArgs(emmy, object, select) + local returns = buildEmmyReturns(emmy) + local enum, rawEnum = buildEnum(emmy) + local tip = emmy.description + return { + label = ('function %s(%s)%s'):format(name, argStr, returns), + name = name, + argStr = argStr, + returns = returns, + description = tip, + enum = enum, + rawEnum = rawEnum, + argLabel = argLabel, + args = args, + } +end diff --git a/script/core/hover/function.lua b/script/core/hover/function.lua new file mode 100644 index 00000000..3865f602 --- /dev/null +++ b/script/core/hover/function.lua @@ -0,0 +1,243 @@ +local emmyFunction = require 'core.hover.emmy_function' + +local function buildValueArgs(func, object, select) + if not func then + return '', nil + end + local names = {} + local values = {} + local options = {} + if func.argValues then + for i, value in ipairs(func.argValues) do + values[i] = value:getType() + end + end + if func.args then + for i, arg in ipairs(func.args) do + names[#names+1] = arg:getName() + local param = func:findEmmyParamByName(arg:getName()) + if param then + values[i] = param:getType() + options[i] = param:getOption() + end + end + end + local strs = {} + local start = 1 + if object then + start = 2 + end + local max + if func:getSource() then + max = #names + else + max = math.max(#names, #values) + end + local args = {} + for i = start, max do + local name = names[i] + local value = values[i] or 'any' + local option = options[i] + if option and option.optional then + if i > start then + strs[#strs+1] = ' [' + else + strs[#strs+1] = '[' + end + end + if i > start then + strs[#strs+1] = ', ' + end + + if i == select then + strs[#strs+1] = '@ARG' + end + if name then + strs[#strs+1] = name .. ': ' .. value + else + strs[#strs+1] = value + end + args[#args+1] = strs[#strs] + if i == select then + strs[#strs+1] = '@ARG' + end + + if option and option.optional == 'self' then + strs[#strs+1] = ']' + end + end + if func:hasDots() then + if max > 0 then + strs[#strs+1] = ', ' + end + strs[#strs+1] = '...' + end + + if options then + for _, option in pairs(options) do + if option.optional == 'after' then + strs[#strs+1] = ']' + end + end + end + + local text = table.concat(strs) + local argLabel = {} + for i = 1, 2 do + local pos = text:find('@ARG', 1, true) + if pos then + if i == 1 then + argLabel[i] = pos + else + argLabel[i] = pos - 1 + end + text = text:sub(1, pos-1) .. text:sub(pos+4) + end + end + if #argLabel == 0 then + argLabel = nil + end + return text, argLabel, args +end + +local function buildValueReturns(func) + if not func then + return '\n -> any' + end + if not func:get 'hasReturn' then + return '' + end + local strs = {} + local emmys = {} + local n = 0 + func:eachEmmyReturn(function (emmy) + n = n + 1 + emmys[n] = emmy + end) + if func.returns then + for i, rtn in ipairs(func.returns) do + local emmy = emmys[i] + local option = emmy and emmy.option + if option and option.optional then + if i > 1 then + strs[#strs+1] = ' [' + else + strs[#strs+1] = '[' + end + end + if i > 1 then + strs[#strs+1] = ('\n% 3d. '):format(i) + end + if emmy and emmy.name then + strs[#strs+1] = ('%s: '):format(emmy.name) + elseif option and option.name then + strs[#strs+1] = ('%s: '):format(option.name) + end + strs[#strs+1] = rtn:getType() + if option and option.optional == 'self' then + strs[#strs+1] = ']' + end + end + for i = 1, #func.returns do + local emmy = emmys[i] + if emmy and emmy.option and emmy.option.optional == 'after' then + strs[#strs+1] = ']' + end + end + end + if #strs == 0 then + strs[1] = 'any' + end + return '\n -> ' .. table.concat(strs) +end + +---@param func emmyFunction +local function buildEnum(func) + if not func then + return nil + end + local params = func:getEmmyParams() + if not params then + return nil + end + local strs = {} + local raw = {} + for _, param in ipairs(params) do + local first = true + local name = param:getName() + raw[name] = {} + param:eachEnum(function (enum) + if first then + first = false + strs[#strs+1] = ('\n%s: %s'):format(param:getName(), param:getType()) + end + if enum.default then + strs[#strs+1] = ('\n |>%s'):format(enum[1]) + else + strs[#strs+1] = ('\n | %s'):format(enum[1]) + end + if enum.comment then + strs[#strs+1] = ' -- ' .. enum.comment + end + raw[name][#raw[name]+1] = enum[1] + end) + end + if #strs == 0 then + return nil + end + return table.concat(strs), raw +end + +local function getComment(func) + if not func then + return nil + end + local comments = {} + local params = func:getEmmyParams() + if params then + for _, param in ipairs(params) do + local option = param:getOption() + if option and option.comment then + comments[#comments+1] = ('+ `%s`*(%s)*: %s'):format(param:getName(), param:getType(), option.comment) + end + end + end + comments[#comments+1] = func:getComment() + if #comments == 0 then + return nil + end + return table.concat(comments, '\n\n') +end + +local function getOverLoads(name, func, object, select) + local overloads = func and func:getEmmyOverLoads() + if not overloads then + return nil + end + local list = {} + for _, ol in ipairs(overloads) do + local hover = emmyFunction(name, ol, object, select) + list[#list+1] = hover.label + end + return table.concat(list, '\n') +end + +return function (name, func, object, select) + local argStr, argLabel, args = buildValueArgs(func, object, select) + local returns = buildValueReturns(func) + local enum, rawEnum = buildEnum(func) + local comment = getComment(func) + local overloads = getOverLoads(name, func, object, select) + return { + label = ('function %s(%s)%s'):format(name, argStr, returns), + name = name, + argStr = argStr, + returns = returns, + description = comment, + enum = enum, + rawEnum = rawEnum, + argLabel = argLabel, + overloads = overloads, + args = args, + } +end diff --git a/script/core/hover/hover.lua b/script/core/hover/hover.lua new file mode 100644 index 00000000..2ee5cf46 --- /dev/null +++ b/script/core/hover/hover.lua @@ -0,0 +1,326 @@ +local findLib = require 'core.find_lib' +local getFunctionHover = require 'core.hover.function' +local getFunctionHoverAsLib = require 'core.hover.lib_function' +local getFunctionHoverAsEmmy = require 'core.hover.emmy_function' +local buildValueName = require 'core.hover.name' + +local OriginTypes = { + ['any'] = true, + ['nil'] = true, + ['integer'] = true, + ['number'] = true, + ['boolean'] = true, + ['string'] = true, + ['thread'] = true, + ['userdata'] = true, + ['table'] = true, + ['function'] = true, +} + +local function longString(str) + for i = 0, 10 do + local finish = ']' .. ('='):rep(i) .. ']' + if not str:find(finish, 1, true) then + return ('[%s[\n%s%s'):format(('='):rep(i), str, finish) + end + end + return ('%q'):format(str) +end + +local function formatString(str) + if #str > 1000 then + str = str:sub(1000) + end + if str:find('[\r\n]') then + str = str:gsub('[\000-\008\011-\012\014-\031\127]', '') + return longString(str) + else + str = str:gsub('[\000-\008\011-\012\014-\031\127]', function (char) + return ('\\%03d'):format(char:byte()) + end) + local single = str:find("'", 1, true) + local double = str:find('"', 1, true) + if single and double then + return longString(str) + elseif double then + return ("'%s'"):format(str) + else + return ('"%s"'):format(str) + end + end +end + +local function formatLiteral(v) + if math.type(v) == 'float' then + return ('%.10f'):format(v):gsub('[0]*$', ''):gsub('%.$', '.0') + elseif type(v) == 'string' then + return formatString(v) + else + return ('%q'):format(v) + end +end + +local function findClass(value) + -- 检查是否有emmy + local emmy = value:getEmmy() + if emmy then + return emmy:getType() + end + -- 检查对象元表 + local metaValue = value:getMetaTable() + if not metaValue then + return nil + end + -- 检查元表中的 __name + local metaName = metaValue:rawGet('__name') + if metaName and type(metaName:getLiteral()) == 'string' then + return metaName:getLiteral() + end + -- 检查元表的 __index + local indexValue = metaValue:rawGet('__index') + if not indexValue then + return nil + end + -- 查找index方法中的以下字段: type name class + -- 允许多重继承 + return indexValue:eachChild(function (k, v) + -- 键值类型必须均为字符串 + if type(k) ~= 'string' then + return + end + if type(v:getLiteral()) ~= 'string' then + return + end + local lKey = k:lower() + if lKey == 'type' + or lKey == 'name' + or lKey == 'class' + then + -- 必须只有过一次赋值 + local hasSet = false + local ok = v:eachInfo(function (info) + if info.type == 'set' then + if hasSet then + return false + else + hasSet = true + end + end + end) + if ok == false then + return false + end + return v:getLiteral() + end + end) +end + +local function formatKey(key) + local kType = type(key) + if kType == 'table' then + key = ('[*%s]'):format(key:getType()) + elseif math.type(key) == 'integer' then + key = ('[%03d]'):format(key) + elseif kType == 'string' then + if key:find '^%d' or key:find '[^%w_]' then + key = ('[%s]'):format(formatString(key)) + end + elseif key == '' then + key = '[*any]' + else + key = ('[%s]'):format(key) + end + return key +end + +local function unpackTable(value) + local lines = {} + value:eachChild(function (key, child) + key = formatKey(key) + + local vType = type(child:getLiteral()) + if vType == 'boolean' + or vType == 'integer' + or vType == 'number' + or vType == 'string' + then + lines[#lines+1] = ('%s: %s = %s'):format(key, child:getType(), formatLiteral(child:getLiteral())) + else + lines[#lines+1] = ('%s: %s'):format(key, child:getType()) + end + end) + local emmy = value:getEmmy() + if emmy then + if emmy.type == 'emmy.arrayType' then + lines[#lines+1] = ('[*integer]: %s'):format(emmy:getName()) + elseif emmy.type == 'emmy.tableType' then + lines[#lines+1] = ('[*%s]: %s'):format(emmy:getKeyType():getType(), emmy:getValueType():getType()) + end + end + if #lines == 0 then + return '{}' + end + + -- 整理一下表 + local cleaned = {} + local used = {} + for _, line in ipairs(lines) do + if used[line] then + goto CONTINUE + end + used[line] = true + if line == '[*any]: any' then + goto CONTINUE + end + cleaned[#cleaned+1] = ' ' .. line .. ',' + :: CONTINUE :: + end + + table.sort(cleaned) + table.insert(cleaned, 1, '{') + cleaned[#cleaned+1] = '}' + return table.concat(cleaned, '\r\n') +end + +local function getValueHover(source, name, value, lib) + local valueType = value:getType() + local class = findClass(value) + + if class then + valueType = class + lib = nil + end + + if not OriginTypes[valueType] then + valueType = '*' .. valueType + end + + local tips = {} + local literal + if lib then + literal = lib.code or (lib.value and formatLiteral(lib.value)) + tips[#tips+1] = lib.description + else + literal = value:getLiteral() and formatLiteral(value:getLiteral()) + end + + tips[#tips+1] = value:getComment() + + local tp + if source:bindLocal() then + tp = 'local' + local loc = source:bindLocal() + if loc.tags then + local mark = {} + local tagBufs = {} + for _, tag in ipairs(loc.tags) do + local tagName = tag[1] + if not mark[tagName] then + mark[tagName] = true + tagBufs[#tagBufs+1] = ('<%s>'):format(tagName) + end + end + name = name .. ' ' .. table.concat(tagBufs, ' ') + end + tips[#tips+1] = loc:getComment() + elseif source:get 'global' then + tp = 'global' + elseif source:get 'simple' then + local simple = source:get 'simple' + if simple[1]:get 'global' then + tp = 'global' + else + tp = 'field' + end + else + tp = 'field' + end + + local text + if valueType == 'table' then + text = ('%s %s: %s'):format(tp, name, unpackTable(value)) + else + if literal == nil then + if class and not OriginTypes[class] then + text = ('%s %s: %s %s'):format(tp, name, valueType, unpackTable(value)) + else + text = ('%s %s: %s'):format(tp, name, valueType) + end + else + text = ('%s %s: %s = %s'):format(tp, name, valueType, literal) + end + end + + local tip + if #tips > 0 then + tip = table.concat(tips, '\n\n-------------\n\n') + end + return { + label = text, + description = tip, + } +end + +local function hoverAsValue(source, lsp, select) + local lib, fullkey = findLib(source) + ---@type value + local value = source:findValue() + local name = fullkey or buildValueName(source) + + local hover + if value:getType() == 'function' then + local object = source:get 'object' + if lib then + hover = getFunctionHoverAsLib(name, lib, object, select) + else + local emmy = value:getEmmy() + if emmy and emmy.type == 'emmy.functionType' then + hover = getFunctionHoverAsEmmy(name, emmy, object, select) + else + local func = value:getFunction() + hover = getFunctionHover(name, func, object, select) + end + end + else + hover = getValueHover(source, name, value, lib) + end + + if not hover then + return nil + end + hover.name = name + return hover +end + +local function hoverAsTargetUri(source, lsp) + local uri = source:get 'target uri' + if not lsp or not lsp.workspace then + return nil + end + local path = lsp.workspace:relativePathByUri(uri) + if not path then + return nil + end + return { + description = ('[%s](%s)'):format(path:string(), uri), + } +end + +return function (source, lsp, select) + if not source then + return nil + end + if source:get 'target uri' then + return hoverAsTargetUri(source, lsp) + end + if source.type == 'name' and source:bindValue() then + return hoverAsValue(source, lsp, select) + end + if source.type == 'simple' then + source = source[#source] + if source.type == 'name' and source:bindValue() then + return hoverAsValue(source, lsp, select) + end + end + return nil +end diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua new file mode 100644 index 00000000..be5b5632 --- /dev/null +++ b/script/core/hover/init.lua @@ -0,0 +1 @@ +return require 'core.hover.hover' diff --git a/script/core/hover/lib_function.lua b/script/core/hover/lib_function.lua new file mode 100644 index 00000000..06087312 --- /dev/null +++ b/script/core/hover/lib_function.lua @@ -0,0 +1,222 @@ +local lang = require 'language' +local config = require 'config' +local function buildLibArgs(lib, object, select) + if not lib.args then + return '' + end + local start + if object then + start = 2 + else + start = 1 + end + local strs = {} + local args = {} + for i = start, #lib.args do + local arg = lib.args[i] + if arg.optional then + if i > start then + strs[#strs+1] = ' [' + else + strs[#strs+1] = '[' + end + end + if i > start then + strs[#strs+1] = ', ' + end + + local argStr = {} + if i == select then + argStr[#argStr+1] = '@ARG' + end + local name = '' + if arg.name then + name = ('%s: '):format(arg.name) + end + if type(arg.type) == 'table' then + name = name .. table.concat(arg.type, '/') + else + name = name .. (arg.type or 'any') + end + argStr[#argStr+1] = name + args[#args+1] = name + if arg.default then + argStr[#argStr+1] = ('(%q)'):format(arg.default) + end + if i == select then + argStr[#argStr+1] = '@ARG' + end + + for _, str in ipairs(argStr) do + strs[#strs+1] = str + end + if arg.optional == 'self' then + strs[#strs+1] = ']' + end + end + for _, arg in ipairs(lib.args) do + if arg.optional == 'after' then + strs[#strs+1] = ']' + end + end + local text = table.concat(strs) + local argLabel = {} + for i = 1, 2 do + local pos = text:find('@ARG', 1, true) + if pos then + if i == 1 then + argLabel[i] = pos + else + argLabel[i] = pos - 1 + end + text = text:sub(1, pos-1) .. text:sub(pos+4) + end + end + if #argLabel == 0 then + argLabel = nil + end + return text, argLabel, args +end + +local function buildLibReturns(lib) + if not lib.returns then + return '' + end + local strs = {} + for i, rtn in ipairs(lib.returns) do + if rtn.optional then + if i > 1 then + strs[#strs+1] = ' [' + else + strs[#strs+1] = '[' + end + end + if i > 1 then + strs[#strs+1] = ('\n% 3d. '):format(i) + end + if rtn.name then + strs[#strs+1] = ('%s: '):format(rtn.name) + end + if type(rtn.type) == 'table' then + strs[#strs+1] = table.concat(rtn.type, '/') + else + strs[#strs+1] = rtn.type or 'any' + end + if rtn.default then + strs[#strs+1] = ('(%q)'):format(rtn.default) + end + if rtn.optional == 'self' then + strs[#strs+1] = ']' + end + end + for _, rtn in ipairs(lib.returns) do + if rtn.optional == 'after' then + strs[#strs+1] = ']' + end + end + return '\n -> ' .. table.concat(strs) +end + +local function buildEnum(lib) + if not lib.enums then + return '' + end + local container = table.container() + for _, enum in ipairs(lib.enums) do + if not enum.name or (not enum.enum and not enum.code) then + goto NEXT_ENUM + end + if not container[enum.name] then + container[enum.name] = {} + if lib.args then + for _, arg in ipairs(lib.args) do + if arg.name == enum.name then + container[enum.name].type = arg.type + break + end + end + end + if lib.returns then + for _, rtn in ipairs(lib.returns) do + if rtn.name == enum.name then + container[enum.name].type = rtn.type + break + end + end + end + end + table.insert(container[enum.name], enum) + ::NEXT_ENUM:: + end + local strs = {} + local raw = {} + for name, enums in pairs(container) do + local tp + if type(enums.type) == 'table' then + tp = table.concat(enums.type, '/') + else + tp = enums.type + end + strs[#strs+1] = ('\n%s: %s'):format(name, tp or 'any') + raw[name] = {} + for _, enum in ipairs(enums) do + if enum.default then + strs[#strs+1] = '\n -> ' + else + strs[#strs+1] = '\n | ' + end + if enum.code then + strs[#strs+1] = tostring(enum.code) + else + strs[#strs+1] = tostring(enum.enum) + end + raw[name][#raw[name]+1] = strs[#strs] + if enum.description then + strs[#strs+1] = ' -- ' .. enum.description + end + end + end + return table.concat(strs), raw +end + +local function buildDoc(lib) + local doc = lib.doc + if not doc then + return + end + if lib.web then + return lang.script(lib.web, doc) + end + local version = config.config.runtime.version + if version == 'Lua 5.1' then + return lang.script('HOVER_DOCUMENT_LUA51', doc) + elseif version == 'Lua 5.2' then + return lang.script('HOVER_DOCUMENT_LUA52', doc) + elseif version == 'Lua 5.3' then + return lang.script('HOVER_DOCUMENT_LUA53', doc) + elseif version == 'Lua 5.4' then + return lang.script('HOVER_DOCUMENT_LUA54', doc) + elseif version == 'LuaJIT' then + return lang.script('HOVER_DOCUMENT_LUAJIT', doc) + end +end + +return function (name, lib, object, select) + local argStr, argLabel, args = buildLibArgs(lib, object, select) + local returns = buildLibReturns(lib) + local enum, rawEnum = buildEnum(lib) + local tip = lib.description + local doc = buildDoc(lib) + return { + label = ('function %s(%s)%s'):format(name, argStr, returns), + name = name, + argStr = argStr, + returns = returns, + description = tip, + enum = enum, + rawEnum = rawEnum, + argLabel = argLabel, + doc = doc, + args = args, + } +end diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua new file mode 100644 index 00000000..763083b9 --- /dev/null +++ b/script/core/hover/name.lua @@ -0,0 +1,38 @@ +local getName = require 'core.name' + +return function (source) + if not source then + return '' + end + local value = source:bindValue() + if not value then + return '' + end + local func = value:getFunction() + local declarat + if func and func:getSource() then + declarat = func:getSource().name + else + declarat = source + end + if not declarat then + -- 如果声明者没有给名字,则找一个合适的名字 + local names = {} + value:eachInfo(function (info, src) + if info.type == 'local' or info.type == 'set' or info.type == 'return' then + if src.type == 'name' and src.uri == value.uri then + names[#names+1] = src + end + end + end) + if #names == 0 then + return '' + end + table.sort(names, function (a, b) + return a.id < b.id + end) + return names[1][1] or '' + end + + return getName(declarat, source) +end diff --git a/script/core/implementation.lua b/script/core/implementation.lua new file mode 100644 index 00000000..f51a97ca --- /dev/null +++ b/script/core/implementation.lua @@ -0,0 +1,204 @@ +local function parseValueSimily(vm, source, lsp) + local key = source[1] + if not key then + return nil + end + local positions = {} + vm:eachSource(function (other) + if other == source then + return + end + if other[1] == key + and not other:bindLocal() + and other:bindValue() + and other:action() == 'set' + and source:bindValue() ~= other:bindValue() + then + positions[#positions+1] = { + other.start, + other.finish, + } + end + end) + if #positions == 0 then + return nil + end + return positions +end + +local function parseValueCrossFile(vm, source, lsp) + local value = source:bindValue() + local positions = {} + value:eachInfo(function (info, src) + if info.type == 'local' and src.uri == value.uri then + positions[#positions+1] = { + src.start, + src.finish, + value.uri, + } + return true + end + end) + if #positions > 0 then + return positions + end + + value:eachInfo(function (info, src) + if info.type == 'set' and src.uri == value.uri then + positions[#positions+1] = { + src.start, + src.finish, + value.uri, + } + end + end) + if #positions > 0 then + return positions + end + + value:eachInfo(function (info, src) + if info.type == 'return' and src.uri == value.uri then + positions[#positions+1] = { + src.start, + src.finish, + value.uri, + } + end + end) + if #positions > 0 then + return positions + end + + local destVM = lsp:getVM(value.uri) + if not destVM then + positions[#positions+1] = { + 0, 0, value.uri, + } + return positions + end + + local result = parseValueSimily(destVM, source, lsp) + if result then + for _, position in ipairs(result) do + positions[#positions+1] = position + position[3] = value.uri + end + end + if #positions > 0 then + return positions + end + + return positions +end + +local function parseValue(vm, source, lsp) + local positions = {} + local mark = {} + + local function callback(src) + if source == src then + return + end + if mark[src] then + return + end + mark[src] = true + if src.start == 0 then + return + end + local uri = src.uri + if uri == '' then + uri = nil + end + positions[#positions+1] = { + src.start, + src.finish, + uri, + } + end + + if source:bindValue() then + source:bindValue():eachInfo(function (info, src) + if info.type == 'set' or info.type == 'local' or info.type == 'return' then + callback(src) + return true + end + end) + end + local parent = source:get 'parent' + if parent then + parent:eachInfo(function (info, src) + if info[1] == source[1] then + if info.type == 'set child' then + callback(src) + end + end + end) + end + if #positions == 0 then + return nil + end + return positions +end + +local function parseLabel(vm, label, lsp) + local positions = {} + label:eachInfo(function (info, src) + if info.type == 'set' then + positions[#positions+1] = { + src.start, + src.finish, + } + end + end) + if #positions == 0 then + return nil + end + return positions +end + +local function jumpUri(vm, source, lsp) + local uri = source:get 'target uri' + local positions = {} + positions[#positions+1] = { + 0, 0, uri, + } + return positions +end + +local function parseClass(vm, source) + local className = source:get 'emmy class' + local positions = {} + vm.emmyMgr:eachClass(className, function (class) + local src = class:getSource() + positions[#positions+1] = { + src.start, + src.finish, + src.uri, + } + end) + return positions +end + +return function (vm, source, lsp) + if not source then + return nil + end + if source:bindValue() then + return parseValue(vm, source, lsp) + or parseValueSimily(vm, source, lsp) + end + if source:bindLabel() then + return parseLabel(vm, source:bindLabel(), lsp) + end + if source:get 'target uri' then + return jumpUri(vm, source, lsp) + end + if source:get 'in index' then + return parseValue(vm, source, lsp) + or parseValueSimily(vm, source, lsp) + end + if source:get 'emmy class' then + return parseClass(vm, source) + end +end diff --git a/script/core/init.lua b/script/core/init.lua new file mode 100644 index 00000000..213dbaca --- /dev/null +++ b/script/core/init.lua @@ -0,0 +1,19 @@ +local api = { + definition = require 'core.definition', + implementation = require 'core.implementation', + references = require 'core.references', + rename = require 'core.rename', + hover = require 'core.hover', + diagnostics = require 'core.diagnostics', + findSource = require 'core.find_source', + findLib = require 'core.find_lib', + completion = require 'core.completion', + signature = require 'core.signature', + documentSymbol = require 'core.document_symbol', + global = require 'core.global', + highlight = require 'core.highlight', + codeAction = require 'core.code_action', + foldingRange = require 'core.folding_range', +} + +return api diff --git a/script/core/library.lua b/script/core/library.lua new file mode 100644 index 00000000..d5edad66 --- /dev/null +++ b/script/core/library.lua @@ -0,0 +1,296 @@ +local lni = require 'lni' +local fs = require 'bee.filesystem' +local config = require 'config' + +local Library = {} + +local function mergeEnum(lib, locale) + if not lib or not locale then + return + end + local pack = {} + for _, enum in ipairs(lib) do + if enum.enum then + pack[enum.enum] = enum + end + if enum.code then + pack[enum.code] = enum + end + end + for _, enum in ipairs(locale) do + if pack[enum.enum] then + if enum.description then + pack[enum.enum].description = enum.description + end + end + if pack[enum.code] then + if enum.description then + pack[enum.code].description = enum.description + end + end + end +end + +local function mergeField(lib, locale) + if not lib or not locale then + return + end + local pack = {} + for _, field in ipairs(lib) do + if field.field then + pack[field.field] = field + end + end + for _, field in ipairs(locale) do + if pack[field.field] then + if field.description then + pack[field.field].description = field.description + end + end + end +end + +local function mergeLocale(libs, locale) + if not libs or not locale then + return + end + for name in pairs(locale) do + if libs[name] then + if locale[name].description then + libs[name].description = locale[name].description + end + mergeEnum(libs[name].enums, locale[name].enums) + mergeField(libs[name].fields, locale[name].fields) + end + end +end + +local function isMatchVersion(version) + if not version then + return true + end + local runtimeVersion = config.config.runtime.version + if type(version) == 'table' then + for i = 1, #version do + if version[i] == runtimeVersion then + return true + end + end + else + if version == runtimeVersion then + return true + end + end + return false +end + +local function insertGlobal(tbl, key, value) + if not isMatchVersion(value.version) then + return false + end + if not value.doc then + value.doc = key + end + tbl[key] = value + return true +end + +local function insertOther(tbl, key, value) + if not value.version then + return + end + if not tbl[key] then + tbl[key] = {} + end + if type(value.version) == 'string' then + tbl[key][#tbl[key]+1] = value.version + elseif type(value.version) == 'table' then + for _, version in ipairs(value.version) do + if type(version) == 'string' then + tbl[key][#tbl[key]+1] = version + end + end + end + table.sort(tbl[key]) +end + +local function insertCustom(tbl, key, value, libName) + if not tbl[key] then + tbl[key] = {} + end + tbl[key][#tbl[key]+1] = libName + table.sort(tbl[key]) +end + +local function isEnableGlobal(libName) + if config.config.runtime.library[libName] then + return true + end + if libName:sub(1, 1) == '@' then + return true + end + return false +end + +local function mergeSource(alllibs, name, lib, libName) + if not lib.source then + if isEnableGlobal(libName) then + local suc = insertGlobal(alllibs.global, name, lib) + if not suc then + insertOther(alllibs.other, name, lib) + end + else + insertCustom(alllibs.custom, name, lib, libName) + end + return + end + for _, source in ipairs(lib.source) do + local sourceName = source.name or name + if source.type == 'global' then + if isEnableGlobal(libName) then + local suc = insertGlobal(alllibs.global, sourceName, lib) + if not suc then + insertOther(alllibs.other, sourceName, lib) + end + else + insertCustom(alllibs.custom, sourceName, lib, libName) + end + elseif source.type == 'library' then + insertGlobal(alllibs.library, sourceName, lib) + elseif source.type == 'object' then + insertGlobal(alllibs.object, sourceName, lib) + end + end +end + +local function copy(t) + local new = {} + for k, v in pairs(t) do + new[k] = v + end + return new +end + +local function insertChild(tbl, name, key, value) + if not name or not key then + return + end + if not isMatchVersion(value.version) then + return + end + if not value.doc then + value.doc = ('%s.%s'):format(name, key) + end + if not tbl[name] then + tbl[name] = { + type = name, + name = name, + child = {}, + } + end + tbl[name].child[key] = copy(value) +end + +local function mergeParent(alllibs, name, lib, libName) + for _, parent in ipairs(lib.parent) do + if parent.type == 'global' then + if isEnableGlobal(libName) then + insertChild(alllibs.global, parent.name, name, lib) + end + elseif parent.type == 'library' then + insertChild(alllibs.library, parent.name, name, lib) + elseif parent.type == 'object' then + insertChild(alllibs.object, parent.name, name, lib) + end + end +end + +local function mergeLibs(alllibs, libs, libName) + if not libs then + return + end + for _, lib in pairs(libs) do + if lib.parent then + mergeParent(alllibs, lib.name, lib, libName) + else + mergeSource(alllibs, lib.name, lib, libName) + end + end +end + +local function loadLocale(language, relative) + local localePath = ROOT / 'locale' / language / relative + local localeBuf = io.load(localePath) + if localeBuf then + local locale = table.container() + xpcall(lni, log.error, localeBuf, localePath:string(), {locale}) + return locale + end + return nil +end + +local function fix(libs) + for name, lib in pairs(libs) do + lib.name = lib.name or name + lib.child = {} + end +end + +local function scan(path) + local result = {path} + local i = 0 + return function () + i = i + 1 + local current = result[i] + if not current then + return nil + end + if fs.is_directory(current) then + for path in current:list_directory() do + result[#result+1] = path + end + end + return current + end +end + +local function init() + local lang = require 'language' + local id = lang.id + Library.global = table.container() + Library.library = table.container() + Library.object = table.container() + Library.other = table.container() + Library.custom = table.container() + + for libPath in (ROOT / 'libs'):list_directory() do + local enableGlobal + local libName = libPath:filename():string() + for path in scan(libPath) do + local libs + local buf = io.load(path) + if buf then + libs = table.container() + xpcall(lni, log.error, buf, path:string(), {libs}) + fix(libs) + end + local relative = fs.relative(path, ROOT) + + local locale = loadLocale('en-US', relative) + mergeLocale(libs, locale) + if id ~= 'en-US' then + locale = loadLocale(id, relative) + mergeLocale(libs, locale) + end + mergeLibs(Library, libs, libName) + end + end +end + +function Library.reload() + init() +end + +init() + +return Library diff --git a/script/core/matchKey.lua b/script/core/matchKey.lua new file mode 100644 index 00000000..b46250cb --- /dev/null +++ b/script/core/matchKey.lua @@ -0,0 +1,30 @@ +return function (me, other) + if me == other then + return true + end + if me == '' then + return true + end + if #me > #other then + return false + end + local lMe = me:lower() + local lOther = other:lower() + if lMe == lOther:sub(1, #lMe) then + return true + end + local chars = {} + for i = 1, #lOther do + local c = lOther:sub(i, i) + chars[c] = (chars[c] or 0) + 1 + end + for i = 1, #lMe do + local c = lMe:sub(i, i) + if chars[c] and chars[c] > 0 then + chars[c] = chars[c] - 1 + else + return false + end + end + return true +end diff --git a/script/core/name.lua b/script/core/name.lua new file mode 100644 index 00000000..54947974 --- /dev/null +++ b/script/core/name.lua @@ -0,0 +1,70 @@ +return function (source, caller) + if not source then + return '' + end + local key + if source:get 'simple' then + local simple = source:get 'simple' + local chars = {} + for i, obj in ipairs(simple) do + if obj.type == 'name' then + chars[i] = obj[1] + elseif obj.type == 'index' then + chars[i] = '[?]' + elseif obj.type == 'call' then + chars[i] = '(?)' + elseif obj.type == ':' then + chars[i] = ':' + elseif obj.type == '.' then + chars[i] = '.' + else + chars[i] = '*' .. obj.type + end + if obj == source then + break + end + end + key = table.concat(chars) + elseif source.type == 'name' then + key = source[1] + elseif source.type == 'string' then + key = ('%q'):format(source[1]) + elseif source.type == 'number' or source.type == 'boolean' then + key = tostring(source[1]) + elseif source.type == 'simple' then + local chars = {} + for i, obj in ipairs(source) do + if obj.type == 'name' then + chars[i] = obj[1] + elseif obj.type == 'index' then + chars[i] = '[?]' + elseif obj.type == 'call' then + chars[i] = '(?)' + elseif obj.type == ':' then + chars[i] = ':' + elseif obj.type == '.' then + chars[i] = '.' + else + chars[i] = '*' .. obj.type + end + end + -- 这里有个特殊处理 + -- function mt:func() 以 mt.func 的形式调用时 + -- hover 显示为 mt.func(self) + if caller then + if chars[#chars-1] == ':' then + if not caller:get 'object' then + chars[#chars-1] = '.' + end + elseif chars[#chars-1] == '.' then + if caller:get 'object' then + chars[#chars-1] = ':' + end + end + end + key = table.concat(chars) + else + key = '' + end + return key +end diff --git a/script/core/references.lua b/script/core/references.lua new file mode 100644 index 00000000..33b38fec --- /dev/null +++ b/script/core/references.lua @@ -0,0 +1,91 @@ +local findSource = require 'core.find_source' + +local function parseResult(vm, source, declarat, callback) + local isGlobal + if source:bindLabel() then + source:bindLabel():eachInfo(function (info, src) + if (declarat and info.type == 'set') or info.type == 'get' then + callback(src) + end + end) + end + if source:bindLocal() then + local loc = source:bindLocal() + callback(loc:getSource()) + loc:eachInfo(function (info, src) + if (declarat and info.type == 'set') or info.type == 'get' then + callback(src) + end + end) + loc:getValue():eachInfo(function (info, src) + if (declarat and (info.type == 'set' or info.type == 'local' or info.type == 'return')) or info.type == 'get' then + callback(src) + end + end) + end + if source:bindFunction() then + if declarat then + callback(source:bindFunction():getSource()) + end + source:bindFunction():eachInfo(function (info, src) + if (declarat and (info.type == 'set' or info.type == 'local')) or info.type == 'get' then + callback(src) + end + end) + end + if source:bindValue() then + source:bindValue():eachInfo(function (info, src) + if (declarat and (info.type == 'set' or info.type == 'local')) or info.type == 'get' then + callback(src) + end + end) + if source:bindValue():isGlobal() then + isGlobal = true + end + end + local parent = source:get 'parent' + if parent then + parent:eachInfo(function (info, src) + if info[1] == source[1] then + if (declarat and info.type == 'set child') or info.type == 'get child' then + callback(src) + end + end + end) + end + --local emmy = source:getEmmy() + --if emmy then + -- if emmy.type == 'emmy.class' or emmy.type == 'emmy.type' --then +-- + -- end + --end + return isGlobal +end + +return function (vm, pos, declarat) + local source = findSource(vm, pos) + if not source then + return nil + end + local positions = {} + local mark = {} + local isGlobal = parseResult(vm, source, declarat, function (src) + if mark[src] then + return + end + mark[src] = true + if src.start == 0 then + return + end + local uri = src.uri + if uri == '' then + uri = nil + end + positions[#positions+1] = { + src.start, + src.finish, + uri, + } + end) + return positions, isGlobal +end diff --git a/script/core/rename.lua b/script/core/rename.lua new file mode 100644 index 00000000..3a2e8532 --- /dev/null +++ b/script/core/rename.lua @@ -0,0 +1,72 @@ +local findSource = require 'core.find_source' +local parser = require 'parser' + +local function parseResult(source, newName) + local positions = {} + if source:bindLabel() then + if not parser:grammar(newName, 'Name') then + return nil + end + source:bindLabel():eachInfo(function (info, src) + positions[#positions+1] = { src.start, src.finish, src:getUri() } + end) + return positions + end + if source:bindLocal() then + local loc = source:bindLocal() + if loc:get 'hide' then + return nil + end + if source:get 'in index' then + if not parser:grammar(newName, 'Exp') then + return positions + end + else + if not parser:grammar(newName, 'Name') then + return positions + end + end + local mark = {} + loc:eachInfo(function (info, src) + if not mark[src] then + mark[src] = info + positions[#positions+1] = { src.start, src.finish, src:getUri() } + end + end) + return positions + end + if source:bindValue() and source:get 'parent' then + if source:get 'in index' then + if not parser:grammar(newName, 'Exp') then + return positions + end + else + if not parser:grammar(newName, 'Name') then + return positions + end + end + local parent = source:get 'parent' + local mark = {} + parent:eachInfo(function (info, src) + if not mark[src] then + mark[src] = info + if info.type == 'get child' or info.type == 'set child' then + if info[1] == source[1] then + positions[#positions+1] = {src.start, src.finish, src:getUri()} + end + end + end + end) + return positions + end + return nil +end + +return function (vm, pos, newName) + local source = findSource(vm, pos) + if not source then + return nil + end + local positions = parseResult(source, newName) + return positions +end diff --git a/script/core/signature.lua b/script/core/signature.lua new file mode 100644 index 00000000..bbe35ffa --- /dev/null +++ b/script/core/signature.lua @@ -0,0 +1,133 @@ +local getFunctionHover = require 'core.hover.function' +local getFunctionHoverAsLib = require 'core.hover.lib_function' +local getFunctionHoverAsEmmy = require 'core.hover.emmy_function' +local findLib = require 'core.find_lib' +local buildValueName = require 'core.hover.name' +local findSource = require 'core.find_source' + +local function findCall(vm, pos) + local results = {} + vm:eachSource(function (src) + if src.type == 'call' + and src.start <= pos + and src.finish >= pos + then + results[#results+1] = src + end + end) + if #results == 0 then + return nil + end + -- 可能处于 'func1(func2(' 的嵌套中,将最近的call放到最前面 + table.sort(results, function (a, b) + return a.start > b.start + end) + return results +end + +local function getSelect(args, pos) + if not args then + return 1 + end + for i, arg in ipairs(args) do + if arg.start <= pos and arg.finish >= pos - 1 then + return i + end + end + return #args + 1 +end + +local function getFunctionSource(call) + local simple = call:get 'simple' + for i, source in ipairs(simple) do + if source == call then + return simple[i-1] + end + end + return nil +end + +local function getHover(call, pos) + local args = call:bindCall() + if not args then + return nil + end + + local value = call:findCallFunction() + if not value then + return nil + end + + local select = getSelect(args, pos) + local source = getFunctionSource(call) + local object = source:get 'object' + local lib, fullkey = findLib(source) + local name = fullkey or buildValueName(source) + local hover + if lib then + hover = getFunctionHoverAsLib(name, lib, object, select) + else + local emmy = value:getEmmy() + if emmy and emmy.type == 'emmy.functionType' then + hover = getFunctionHoverAsEmmy(name, emmy, object, select) + else + ---@type emmyFunction + local func = value:getFunction() + hover = getFunctionHover(name, func, object, select) + local overLoads = func and func:getEmmyOverLoads() + if overLoads then + for _, ol in ipairs(overLoads) do + hover = getFunctionHoverAsEmmy(name, ol, object, select) + end + end + end + end + return hover +end + +local function isInFunctionOrTable(call, pos) + local args = call:bindCall() + if not args then + return false + end + local select = getSelect(args, pos) + local arg = args[select] + if not arg then + return false + end + if arg.type == 'function' or arg.type == 'table' then + return true + end + return false +end + +return function (vm, pos) + local source = findSource(vm, pos) or findSource(vm, pos-1) + if not source or source.type == 'string' then + return + end + local calls = findCall(vm, pos) + if not calls or #calls == 0 then + return nil + end + + local nearCall = calls[1] + if isInFunctionOrTable(nearCall, pos) then + return nil + end + + local hover = getHover(nearCall, pos) + if not hover then + return nil + end + + -- skip `name(` + local head = #hover.name + 1 + hover.label = ('%s(%s)'):format(hover.name, hover.argStr) + if hover.argLabel then + hover.argLabel[1] = hover.argLabel[1] + head + hover.argLabel[2] = hover.argLabel[2] + head + end + + return { hover } +end diff --git a/script/core/snippet.lua b/script/core/snippet.lua new file mode 100644 index 00000000..7532ce9b --- /dev/null +++ b/script/core/snippet.lua @@ -0,0 +1,64 @@ +local snippet = {} + +local function add(cate, key, label) + return function (text) + if not snippet[cate] then + snippet[cate] = {} + end + if not snippet[cate][key] then + snippet[cate][key] = {} + end + snippet[cate][key][#snippet[cate][key]+1] = { + label = label, + text = text, + } + end +end + +add('key', 'do', 'do .. end') [[ +do + $0 +end]] + +add('key', 'elseif', 'elseif .. then') +[[elseif ${1:true} then]] + +add('key', 'for', 'for .. in') [[ +for ${1:key, value} in ${2:pairs(t)} do + $0 +end]] + +add('key', 'for', 'for i = ..') [[ +for ${1:i} = ${2:1}, ${3:10, 2} do + $0 +end]] + +add('key', 'function', 'function ()') [[ +function $1(${2:arg1, arg2, arg3}) + $0 +end]] + +add('key', 'local', 'local function') [[ +local function ${1:name}(${2:arg1, arg2, arg3}) + $0 +end]] + +add('key', 'if', 'if .. then') [[ +if ${1:true} then + $0 +end]] + +add('key', 'repeat', 'repeat .. until') [[ +repeat + $0 +until ${1:true}]] + +add('key', 'while', 'while .. do') [[ +while ${1:true} do + $0 +end]] + +add('key', 'return', 'do return end') +[[do return ${1:true} end]] + +return snippet |