diff options
author | fesily <fesil@foxmail.com> | 2024-02-21 15:08:38 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-21 15:08:38 +0800 |
commit | 74e74d2a5564f44abdfb31a0a23bf3aa7ec16fed (patch) | |
tree | 67b0df14295e7b3d3f37762781cc0a031965fedc /script | |
parent | 647971cb339d0c3d5aa1da26041ab9ddb71e963b (diff) | |
parent | 3e6fd3ce1f2f0528336ded939d776a29bbfaf2eb (diff) | |
download | lua-language-server-74e74d2a5564f44abdfb31a0a23bf3aa7ec16fed.zip |
Merge branch 'master' into generic-pattern1
Diffstat (limited to 'script')
-rw-r--r-- | script/cli/doc.lua | 19 | ||||
-rw-r--r-- | script/cli/init.lua | 5 | ||||
-rw-r--r-- | script/cli/visualize.lua | 103 | ||||
-rw-r--r-- | script/core/rename.lua | 6 | ||||
-rw-r--r-- | script/core/semantic-tokens.lua | 8 | ||||
-rw-r--r-- | script/encoder/ansi.lua | 12 | ||||
-rw-r--r-- | script/global.d.lua | 4 | ||||
-rw-r--r-- | script/parser/guide.lua | 8 | ||||
-rw-r--r-- | script/parser/luadoc.lua | 86 | ||||
-rw-r--r-- | script/plugin.lua | 57 | ||||
-rw-r--r-- | script/plugins/astHelper.lua | 25 | ||||
-rw-r--r-- | script/plugins/nodeHelper.lua | 75 | ||||
-rw-r--r-- | script/proto/proto.lua | 11 | ||||
-rw-r--r-- | script/vm/compiler.lua | 80 | ||||
-rw-r--r-- | script/vm/infer.lua | 20 | ||||
-rw-r--r-- | script/workspace/scope.lua | 2 |
16 files changed, 457 insertions, 64 deletions
diff --git a/script/cli/doc.lua b/script/cli/doc.lua index 9140a258..fb9b0a8e 100644 --- a/script/cli/doc.lua +++ b/script/cli/doc.lua @@ -169,6 +169,7 @@ local function collectTypes(global, results) field.desc = getDesc(source) field.rawdesc = getDesc(source, true) field.extends = packObject(source.extends) + field.visible = vm.getVisibleType(source) return end if source.type == 'setfield' @@ -187,6 +188,7 @@ local function collectTypes(global, results) field.desc = getDesc(source) field.rawdesc = getDesc(source, true) field.extends = packObject(source.value) + field.visible = vm.getVisibleType(source) return end if source.type == 'tableindex' then @@ -207,6 +209,7 @@ local function collectTypes(global, results) field.desc = getDesc(source) field.rawdesc = getDesc(source, true) field.extends = packObject(source.value) + field.visible = vm.getVisibleType(source) return end end) @@ -245,6 +248,8 @@ local function collectVars(global, results) } result.desc = result.desc or getDesc(set) result.rawdesc = result.rawdesc or getDesc(set, true) + result.defines[#result.defines].extends['desc'] = getDesc(set) + result.defines[#result.defines].extends['rawdesc'] = getDesc(set, true) end end if #result.defines == 0 then @@ -292,6 +297,18 @@ function export.export(outputPath, callback) return docPath, mdPath end +function export.getDocOutputPath() + local doc_output_path = '' + if type(DOC_OUT_PATH) == 'string' then + doc_output_path = fs.absolute(fs.path(DOC_OUT_PATH)):string() + elseif DOC_OUT_PATH == true then + doc_output_path = fs.current_path():string() + else + doc_output_path = LOGPATH + end + return doc_output_path +end + ---@async ---@param outputPath string function export.makeDoc(outputPath) @@ -350,7 +367,7 @@ function export.runCLI() ws.awaitReady(rootUri) await.sleep(0.1) - local docPath, mdPath = export.export(LOGPATH, function (i, max) + local docPath, mdPath = export.export(export.getDocOutputPath(), function (i, max) if os.clock() - lastClock > 0.2 then lastClock = os.clock() local output = '\x0D' diff --git a/script/cli/init.lua b/script/cli/init.lua index b5a9f86d..6d7fc0ff 100644 --- a/script/cli/init.lua +++ b/script/cli/init.lua @@ -12,3 +12,8 @@ if _G['DOC'] then require 'cli.doc' .runCLI() os.exit(0, true) end + +if _G['VISUALIZE'] then + local ret = require 'cli.visualize' .runCLI() + os.exit(ret or 0, true) +end diff --git a/script/cli/visualize.lua b/script/cli/visualize.lua new file mode 100644 index 00000000..29269b82 --- /dev/null +++ b/script/cli/visualize.lua @@ -0,0 +1,103 @@ +local lang = require 'language' +local parser = require 'parser' +local guide = require 'parser.guide' + +local function nodeId(node) + return node.type .. ':' .. node.start .. ':' .. node.finish +end + +local function shorten(str) + if type(str) ~= 'string' then + return str + end + str = str:gsub('\n', '\\\\n') + if #str <= 20 then + return str + else + return str:sub(1, 17) .. '...' + end +end + +local function getTooltipLine(k, v) + if type(v) == 'table' then + if v.type then + v = '<node ' .. v.type .. '>' + else + v = '<table>' + end + end + v = tostring(v) + v = v:gsub('"', '\\"') + return k .. ': ' .. shorten(v) .. '\\n' +end + +local function getTooltip(node) + local str = '' + local skipNodes = {parent = true, start = true, finish = true, type = true} + str = str .. getTooltipLine('start', node.start) + str = str .. getTooltipLine('finish', node.finish) + for k, v in pairs(node) do + if type(k) ~= 'number' and not skipNodes[k] then + str = str .. getTooltipLine(k, v) + end + end + for i = 1, math.min(#node, 15) do + str = str .. getTooltipLine(i, node[i]) + end + if #node > 15 then + str = str .. getTooltipLine('15..' .. #node, '(...)') + end + return str +end + +local nodeEntry = '\t"%s" [\n\t\tlabel="%s\\l%s\\l"\n\t\ttooltip="%s"\n\t]' +local function getNodeLabel(node) + local keyName = guide.getKeyName(node) + if node.type == 'binary' or node.type == 'unary' then + keyName = node.op.type + elseif node.type == 'label' or node.type == 'goto' then + keyName = node[1] + end + return nodeEntry:format(nodeId(node), node.type, shorten(keyName) or '', getTooltip(node)) +end + +local function getVisualizeVisitor(writer) + local function visitNode(node, parent) + if node == nil then return end + writer:write(getNodeLabel(node)) + writer:write('\n') + if parent then + writer:write(('\t"%s" -> "%s"'):format(nodeId(parent), nodeId(node))) + writer:write('\n') + end + guide.eachChild(node, function(child) + visitNode(child, node) + end) + end + return visitNode +end + + +local export = {} + +function export.visualizeAst(code, writer) + local state = parser.compile(code, 'Lua', _G['LUA_VER'] or 'Lua 5.4') + writer:write('digraph AST {\n') + writer:write('\tnode [shape = rect]\n') + getVisualizeVisitor(writer)(state.ast) + writer:write('}\n') +end + +function export.runCLI() + lang(LOCALE) + local file = _G['VISUALIZE'] + local code, err = io.open(file) + if not code then + io.stderr:write('failed to open ' .. file .. ': ' .. err) + return 1 + end + code = code:read('a') + return export.visualizeAst(code, io.stdout) +end + +return export diff --git a/script/core/rename.lua b/script/core/rename.lua index cc5d37f3..534a972a 100644 --- a/script/core/rename.lua +++ b/script/core/rename.lua @@ -414,7 +414,11 @@ function m.rename(uri, pos, newname) return end mark[uid] = true - if files.isLibrary(turi, true) then + if vm.isMetaFile(turi) then + return + end + if files.isLibrary(turi, true) + and not files.isLibrary(uri, true) then return end results[#results+1] = { diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua index 4e1d8e00..e908ef7b 100644 --- a/script/core/semantic-tokens.lua +++ b/script/core/semantic-tokens.lua @@ -882,6 +882,10 @@ return function (uri, start, finish) local n = 0 guide.eachSourceBetween(state.ast, start, finish, function (source) ---@async + -- skip virtual source + if source.virtual then + return + end Care(source.type, source, options, results) n = n + 1 if n % 100 == 0 then @@ -890,6 +894,10 @@ return function (uri, start, finish) end) for _, comm in ipairs(state.comms) do + -- skip virtual comment + if comm.virtual then + return + end if start <= comm.start and comm.finish <= finish then local headPos = (comm.type == 'comment.short' and comm.text:match '^%-%s*[@|]()') or (comm.type == 'comment.long' and comm.text:match '^@()') diff --git a/script/encoder/ansi.lua b/script/encoder/ansi.lua index f5273c51..7cb64ec3 100644 --- a/script/encoder/ansi.lua +++ b/script/encoder/ansi.lua @@ -1,24 +1,24 @@ local platform = require 'bee.platform' -local unicode +local windows if platform.OS == 'Windows' then - unicode = require 'bee.unicode' + windows = require 'bee.windows' end local m = {} function m.toutf8(text) - if not unicode then + if not windows then return text end - return unicode.a2u(text) + return windows.a2u(text) end function m.fromutf8(text) - if not unicode then + if not windows then return text end - return unicode.u2a(text) + return windows.u2a(text) end return m diff --git a/script/global.d.lua b/script/global.d.lua index b44d6371..cee9e01b 100644 --- a/script/global.d.lua +++ b/script/global.d.lua @@ -52,6 +52,10 @@ CHECK = '' ---@type string DOC = '' +--output directory path for documentation (doc.json, ...) +---@type string +DOC_OUT_PATH = '' + ---@type string | '"Error"' | '"Warning"' | '"Information"' | '"Hint"' CHECKLEVEL = 'Warning' diff --git a/script/parser/guide.lua b/script/parser/guide.lua index 4e71c832..fd779da0 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -10,7 +10,7 @@ local type = type ---@field type string ---@field special string ---@field tag string ----@field args { [integer]: parser.object, start: integer, finish: integer } +---@field args { [integer]: parser.object, start: integer, finish: integer, type: string } ---@field locals parser.object[] ---@field returns? parser.object[] ---@field breaks? parser.object[] @@ -1313,12 +1313,18 @@ end function m.getParams(source) if source.type == 'call' then local args = source.args + if not args then + return + end assert(args.type == 'callargs', 'call.args type is\'t callargs') return args elseif source.type == 'callargs' then return source elseif source.type == 'function' then local args = source.args + if not args then + return + end assert(args.type == 'funcargs', 'function.args type is\'t callargs') return args end diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index 47396df6..edbfd34e 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -369,6 +369,78 @@ local function parseTable(parent) return typeUnit end +local function parseTuple(parent) + if not checkToken('symbol', '[', 1) then + return nil + end + nextToken() + local typeUnit = { + type = 'doc.type.table', + start = getStart(), + parent = parent, + fields = {}, + isTuple = true, + } + + local index = 1 + while true do + if checkToken('symbol', ']', 1) then + nextToken() + break + end + local field = { + type = 'doc.type.field', + parent = typeUnit, + } + + do + local needCloseParen + if checkToken('symbol', '(', 1) then + nextToken() + needCloseParen = true + end + field.name = { + type = 'doc.type', + start = getFinish(), + firstFinish = getFinish(), + finish = getFinish(), + parent = field, + } + field.name.types = { + [1] = { + type = 'doc.type.integer', + start = getFinish(), + finish = getFinish(), + parent = field.name, + [1] = index, + } + } + index = index + 1 + field.extends = parseType(field) + if not field.extends then + break + end + field.optional = field.extends.optional + field.start = field.extends.start + field.finish = field.extends.finish + if needCloseParen then + nextSymbolOrError ')' + end + end + + typeUnit.fields[#typeUnit.fields+1] = field + if checkToken('symbol', ',', 1) + or checkToken('symbol', ';', 1) then + nextToken() + else + nextSymbolOrError(']') + break + end + end + typeUnit.finish = getFinish() + return typeUnit +end + local function parseSigns(parent) if not checkToken('symbol', '<', 1) then return nil @@ -729,6 +801,7 @@ end function parseTypeUnit(parent) local result = parseFunction(parent) or parseTable(parent) + or parseTuple(parent) or parseString(parent) or parseCode(parent) or parseInteger(parent) @@ -912,6 +985,7 @@ local docSwitch = util.switch() while true do local extend = parseName('doc.extends.name', result) or parseTable(result) + or parseTuple(result) if not extend then pushWarning { type = 'LUADOC_MISS_CLASS_EXTENDS_NAME', @@ -2050,7 +2124,10 @@ local function bindDocs(state) state.ast.docs.groups[#state.ast.docs.groups+1] = binded end binded[#binded+1] = doc - if isTailComment(text, doc) then + if doc.specialBindGroup then + bindDocWithSources(sources, doc.specialBindGroup) + binded = nil + elseif isTailComment(text, doc) and doc.type ~= "doc.class" and doc.type ~= "doc.field" then bindDocWithSources(sources, binded) binded = nil else @@ -2154,11 +2231,13 @@ local function luadoc(state) end end end - + if ast.state.pluginDocs then for i, doc in ipairs(ast.state.pluginDocs) do insertDoc(doc, doc.originalComment) end + ---@param a unknown + ---@param b unknown table.sort(ast.docs, function (a, b) return a.start < b.start end) @@ -2176,7 +2255,7 @@ local function luadoc(state) end return { - buildAndBindDoc = function (ast, src, comment) + buildAndBindDoc = function (ast, src, comment, group) local doc = buildLuaDoc(comment) if doc then local pluginDocs = ast.state.pluginDocs or {} @@ -2184,6 +2263,7 @@ return { doc.special = src doc.originalComment = comment doc.virtual = true + doc.specialBindGroup = group ast.state.pluginDocs = pluginDocs return doc end diff --git a/script/plugin.lua b/script/plugin.lua index 7a661e0d..b297cd9b 100644 --- a/script/plugin.lua +++ b/script/plugin.lua @@ -7,6 +7,15 @@ local scope = require 'workspace.scope' local ws = require 'workspace' local fs = require 'bee.filesystem' +---@class pluginInterfaces +local pluginConfigs = { + -- create plugin for vm module + VM = { + OnCompileFunctionParam = function (next, func, source) + end + } +} + ---@class plugin local m = {} @@ -51,6 +60,15 @@ function m.dispatch(event, uri, ...) return failed == 0, res1, res2 end +function m.getVmPlugin(uri) + local scp = scope.getScope(uri) + local interfaces = scp:get('pluginInterfaces') + if not interfaces then + return + end + return interfaces.VM +end + ---@async ---@param scp scope local function checkTrustLoad(scp) @@ -78,6 +96,40 @@ local function checkTrustLoad(scp) return true end +local function createMethodGroup(interfaces, key, methods) + local methodGroup = {} + + for method in pairs(methods) do + local funcs = setmetatable({}, { + __call = function (t, next, ...) + if #t == 0 then + return next(...) + else + local result + for _, fn in ipairs(t) do + result = fn(next, ...) + end + return result + end + end + }) + for _, interface in ipairs(interfaces) do + local func = interface[method] + if not func then + local namespace = interface[key] + if namespace then + func = namespace[method] + end + end + if func then + funcs[#funcs+1] = func + end + end + methodGroup[method] = funcs + end + return #methodGroup>0 and methodGroup or nil +end + ---@param uri uri local function initPlugin(uri) await.call(function () ---@async @@ -148,6 +200,11 @@ local function initPlugin(uri) end interfaces[#interfaces+1] = interface end + + for key, config in pairs(pluginConfigs) do + interfaces[key] = createMethodGroup(interfaces, key, config) + end + ws.resetFiles(scp) end) end diff --git a/script/plugins/astHelper.lua b/script/plugins/astHelper.lua index aba09478..bfe2dd27 100644 --- a/script/plugins/astHelper.lua +++ b/script/plugins/astHelper.lua @@ -23,14 +23,27 @@ end ---@param ast parser.object ---@param source parser.object local/global variable ---@param classname string -function _M.addClassDoc(ast, source, classname) +---@param group table? +function _M.addClassDoc(ast, source, classname, group) + return _M.addDoc(ast, source, "class", classname, group) +end + +--- give the local/global variable a luadoc comment +---@param ast parser.object +---@param source parser.object local/global variable +---@param key string +---@param value string +---@param group table? +function _M.addDoc(ast, source, key, value, group) if source.type ~= 'local' and not guide.isGlobal(source) then return false end - --TODO fileds - --TODO callers - local comment = _M.buildComment("class", classname, source.start - 1) - return luadoc.buildAndBindDoc(ast, source, comment) + local comment = _M.buildComment(key, value, source.start - 1) + local doc = luadoc.buildAndBindDoc(ast, source, comment, group) + if group then + group[#group+1] = doc + end + return doc end ---remove `ast` function node `index` arg, the variable will be the function local variable @@ -57,7 +70,7 @@ end function _M.addClassDocAtParam(ast, classname, source, index) local arg = _M.removeArg(source, index) if arg then - return _M.addClassDoc(ast, arg, classname), arg + return not not _M.addClassDoc(ast, arg, classname), arg end return false end diff --git a/script/plugins/nodeHelper.lua b/script/plugins/nodeHelper.lua new file mode 100644 index 00000000..3f90b152 --- /dev/null +++ b/script/plugins/nodeHelper.lua @@ -0,0 +1,75 @@ +local vm = require 'vm' +local guide = require 'parser.guide' + +local _M = {} + +---@class node.match.pattern +---@field next node.match.pattern? + +local function deepCompare(source, pattern) + local type1, type2 = type(source), type(pattern) + if type1 ~= type2 then + return false + end + + if type1 ~= "table" then + return source == pattern + end + + for key2, value2 in pairs(pattern) do + local value1 = source[key2] + if value1 == nil or not deepCompare(value1, value2) then + return false + end + end + + return true +end + +---@param source parser.object +---@param pattern node.match.pattern +---@return boolean +function _M.matchPattern(source, pattern) + if source.type == 'local' then + if source.parent.type == 'funcargs' and source.parent.parent.type == 'function' then + for i, ref in ipairs(source.ref) do + if deepCompare(ref, pattern) then + return true + end + end + end + end + return false +end + +local vaildVarRegex = "()([a-zA-Z][a-zA-Z0-9_]*)()" +---创建类型 *.field.field形式的 pattern +---@param pattern string +---@return node.match.pattern?, string? +function _M.createFieldPattern(pattern) + local ret = { next = nil } + local next = ret + local init = 1 + while true do + local startpos, matched, endpos + if pattern:sub(1, 1) == "*" then + startpos, matched, endpos = init, "*", init + 1 + else + startpos, matched, endpos = vaildVarRegex:match(pattern, init) + end + if not startpos then + break + end + if startpos ~= init then + return nil, "invalid pattern" + end + local field = matched == "*" and { next = nil } + or { field = { type = 'field', matched }, type = 'getfield', next = nil } + next.next = field + next = field + pattern = pattern:sub(endpos) + end + return ret +end + +return _M diff --git a/script/proto/proto.lua b/script/proto/proto.lua index d01c8f36..2460b4ec 100644 --- a/script/proto/proto.lua +++ b/script/proto/proto.lua @@ -1,5 +1,3 @@ -local subprocess = require 'bee.subprocess' -local socket = require 'bee.socket' local util = require 'utility' local await = require 'await' local pub = require 'pub' @@ -7,7 +5,7 @@ local jsonrpc = require 'jsonrpc' local define = require 'proto.define' local json = require 'json' local inspect = require 'inspect' -local thread = require 'bee.thread' +local platform = require 'bee.platform' local fs = require 'bee.filesystem' local net = require 'service.net' local timer = require 'timer' @@ -234,8 +232,11 @@ end function m.listen(mode, socketPort) m.mode = mode if mode == 'stdio' then - subprocess.filemode(io.stdin, 'b') - subprocess.filemode(io.stdout, 'b') + if platform.OS == 'Windows' then + local windows = require 'bee.windows' + windows.filemode(io.stdin, 'b') + windows.filemode(io.stdout, 'b') + end io.stdin:setvbuf 'no' io.stdout:setvbuf 'no' pub.task('loadProtoByStdio') diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 2222fa9b..2253c83a 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -5,6 +5,7 @@ local rpath = require 'workspace.require-path' local files = require 'files' ---@class vm local vm = require 'vm.vm' +local plugin = require 'plugin' ---@class parser.object ---@field _compiledNodes boolean @@ -1031,6 +1032,55 @@ local function compileForVars(source, target) end ---@param source parser.object +local function compileFunctionParam(func, source) + -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number + local funcNode = vm.compileNode(func) + for n in funcNode:eachObject() do + if n.type == 'doc.type.function' then + for index, arg in ipairs(n.args) do + if func.args[index] == source then + local argNode = vm.compileNode(arg) + for an in argNode:eachObject() do + if an.type ~= 'doc.generic.name' then + vm.setNode(source, an) + end + end + return true + end + end + end + end + if func.parent.type == 'local' then + local refs = func.parent.ref + local findCall + if refs then + for i, ref in ipairs(refs) do + if ref.parent.type == 'call' then + findCall = ref.parent + break + end + end + end + if findCall and findCall.args then + local index + for i, arg in ipairs(source.parent) do + if arg == source then + index = i + break + end + end + if index then + local callerArg = findCall.args[index] + if callerArg then + vm.setNode(source, vm.compileNode(callerArg)) + return true + end + end + end + end +end + +---@param source parser.object local function compileLocal(source) local myNode = vm.setNode(source, source) @@ -1069,7 +1119,6 @@ local function compileLocal(source) vm.setNode(source, vm.compileNode(setfield.node)) end end - if source.parent.type == 'funcargs' and not hasMarkDoc and not hasMarkParam then local func = source.parent.parent -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number @@ -1090,35 +1139,6 @@ local function compileLocal(source) end end end - if not hasDocArg - and func.parent.type == 'local' then - local refs = func.parent.ref - local findCall - if refs then - for i, ref in ipairs(refs) do - if ref.parent.type == 'call' then - findCall = ref.parent - break - end - end - end - if findCall and findCall.args then - local index - for i, arg in ipairs(source.parent) do - if arg == source then - index = i - break - end - end - if index then - local callerArg = findCall.args[index] - if callerArg then - hasDocArg = true - vm.setNode(source, vm.compileNode(callerArg)) - end - end - end - end if not hasDocArg then vm.setNode(source, vm.declareGlobal('type', 'any')) end diff --git a/script/vm/infer.lua b/script/vm/infer.lua index f2673ed3..3f3d0e3a 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -157,22 +157,24 @@ local viewNodeSwitch;viewNodeSwitch = util.switch() end infer._hasClass = true local buf = {} - buf[#buf+1] = '{ ' + buf[#buf+1] = source.isTuple and '[' or '{ ' for i, field in ipairs(source.fields) do if i > 1 then buf[#buf+1] = ', ' end - local key = field.name - if key.type == 'doc.type' then - buf[#buf+1] = ('[%s]: '):format(vm.getInfer(key):view(uri)) - elseif type(key[1]) == 'string' then - buf[#buf+1] = key[1] .. ': ' - else - buf[#buf+1] = ('[%q]: '):format(key[1]) + if not source.isTuple then + local key = field.name + if key.type == 'doc.type' then + buf[#buf+1] = ('[%s]: '):format(vm.getInfer(key):view(uri)) + elseif type(key[1]) == 'string' then + buf[#buf+1] = key[1] .. ': ' + else + buf[#buf+1] = ('[%q]: '):format(key[1]) + end end buf[#buf+1] = vm.getInfer(field.extends):view(uri) end - buf[#buf+1] = ' }' + buf[#buf+1] = source.isTuple and ']' or ' }' return table.concat(buf) end) : case 'doc.type.string' diff --git a/script/workspace/scope.lua b/script/workspace/scope.lua index da72a1eb..789b5f81 100644 --- a/script/workspace/scope.lua +++ b/script/workspace/scope.lua @@ -125,8 +125,6 @@ function mt:set(k, v) return v end ----@param k string ----@return any function mt:get(k) return self._data[k] end |