diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2021-08-05 21:05:21 +0800 |
---|---|---|
committer | 最萌小汐 <sumneko@hotmail.com> | 2021-08-05 21:05:21 +0800 |
commit | caec7495dc9001a110642b2d5eb5a8d7fef76dd3 (patch) | |
tree | 48ab4d514bddd4e42927109465e30663facf4f04 /script/core | |
parent | 5fc7a3bbf37751f70db4f36ee566c0ef76e15bb4 (diff) | |
parent | d5ed4fb29ef4d1275349bab60f5f5e9500dba4d2 (diff) | |
download | lua-language-server-caec7495dc9001a110642b2d5eb5a8d7fef76dd3.zip |
Merge branch 'performance'
Diffstat (limited to 'script/core')
-rw-r--r-- | script/core/command/removeSpace.lua | 2 | ||||
-rw-r--r-- | script/core/completion.lua | 19 | ||||
-rw-r--r-- | script/core/definition.lua | 6 | ||||
-rw-r--r-- | script/core/diagnostics/deprecated.lua | 31 | ||||
-rw-r--r-- | script/core/diagnostics/duplicate-doc-class.lua | 2 | ||||
-rw-r--r-- | script/core/diagnostics/init.lua | 2 | ||||
-rw-r--r-- | script/core/generic.lua | 16 | ||||
-rw-r--r-- | script/core/hover/label.lua | 3 | ||||
-rw-r--r-- | script/core/infer.lua | 8 | ||||
-rw-r--r-- | script/core/keyword.lua | 4 | ||||
-rw-r--r-- | script/core/noder.lua | 1147 | ||||
-rw-r--r-- | script/core/reference.lua | 6 | ||||
-rw-r--r-- | script/core/searcher.lua | 945 | ||||
-rw-r--r-- | script/core/type-definition.lua | 6 |
14 files changed, 1257 insertions, 940 deletions
diff --git a/script/core/command/removeSpace.lua b/script/core/command/removeSpace.lua index 2ef178f3..b94f9788 100644 --- a/script/core/command/removeSpace.lua +++ b/script/core/command/removeSpace.lua @@ -26,7 +26,7 @@ return function (data) local pos = line:find '[ \t]+$' if pos then local start, finish = guide.lineRange(lines, i, true) - start = start + pos - 1 + start = start + pos if isInString(ast, start) then goto NEXT_LINE end diff --git a/script/core/completion.lua b/script/core/completion.lua index d5cca4e0..26776634 100644 --- a/script/core/completion.lua +++ b/script/core/completion.lua @@ -636,22 +636,11 @@ local function checkCommon(myUri, word, text, offset, results) if myUri and files.eq(myUri, uri) then goto CONTINUE end - local cache = files.getCache(uri) - if not cache.commonWords then - cache.commonWords = {} - local mark = {} - for str in files.getText(uri):gmatch '([%a_][%w_]+)' do - if #str >= 3 and not mark[str] then - mark[str] = true - local head = str:sub(1, 2) - if not cache.commonWords[head] then - cache.commonWords[head] = {} - end - cache.commonWords[head][#cache.commonWords[head]+1] = str - end - end + local words = files.getWordsOfHead(uri, myHead) + if not words then + goto CONTINUE end - for _, str in ipairs(cache.commonWords[myHead] or {}) do + for _, str in ipairs(words) do if #results >= 100 then break end diff --git a/script/core/definition.lua b/script/core/definition.lua index b6b45871..fb74b73a 100644 --- a/script/core/definition.lua +++ b/script/core/definition.lua @@ -8,8 +8,8 @@ local guide = require 'parser.guide' local function sortResults(results) -- 先按照顺序排序 table.sort(results, function (a, b) - local u1 = searcher.getUri(a.target) - local u2 = searcher.getUri(b.target) + local u1 = guide.getUri(a.target) + local u2 = guide.getUri(b.target) if u1 == u2 then return a.target.start < b.target.start else @@ -21,7 +21,7 @@ local function sortResults(results) for i = #results, 1, -1 do local res = results[i].target local f = res.finish - local uri = searcher.getUri(res) + local uri = guide.getUri(res) if lf and f > lf and uri == lu then table.remove(results, i) else diff --git a/script/core/diagnostics/deprecated.lua b/script/core/diagnostics/deprecated.lua index 22bf70c0..0aeac9e9 100644 --- a/script/core/diagnostics/deprecated.lua +++ b/script/core/diagnostics/deprecated.lua @@ -7,6 +7,7 @@ local define = require 'proto.define' local await = require 'await' local noder = require 'core.noder' +local types = {'getglobal', 'getfield', 'getindex', 'getmethod'} return function (uri, callback) local ast = files.getState(uri) if not ast then @@ -15,13 +16,7 @@ return function (uri, callback) local cache = {} - guide.eachSource(ast.ast, function (src) - if src.type ~= 'getglobal' - and src.type ~= 'getfield' - and src.type ~= 'getindex' - and src.type ~= 'getmethod' then - return - end + guide.eachSourceTypes(ast.ast, types, function (src) if src.type == 'getglobal' then local key = src[1] if not key then @@ -40,17 +35,29 @@ return function (uri, callback) return end - if cache[id] then + if cache[id] == false then return end + if cache[id] then + callback { + start = src.start, + finish = src.finish, + tags = { define.DiagnosticTag.Deprecated }, + message = cache[id].message, + data = cache[id].data, + } + end + await.delay() if not vm.isDeprecated(src, true) then - cache[id] = true + cache[id] = false return end + await.delay() + local defs = vm.getDefs(src) local validVersions for _, def in ipairs(defs) do @@ -78,6 +85,12 @@ return function (uri, callback) message = ('%s(%s)'):format(message, lang.script('DIAG_DEFINED_VERSION', table.concat(versions, '/'), config.get 'Lua.runtime.version')) end end + cache[id] = { + message = message, + data = { + versions = versions, + }, + } callback { start = src.start, diff --git a/script/core/diagnostics/duplicate-doc-class.lua b/script/core/diagnostics/duplicate-doc-class.lua index 20eedb5e..780d15b9 100644 --- a/script/core/diagnostics/duplicate-doc-class.lua +++ b/script/core/diagnostics/duplicate-doc-class.lua @@ -28,7 +28,7 @@ return function (uri, callback) cache[name][#cache[name]+1] = { start = otherDoc.start, finish = otherDoc.finish, - uri = searcher.getUri(otherDoc), + uri = guide.getUri(otherDoc), } end end diff --git a/script/core/diagnostics/init.lua b/script/core/diagnostics/init.lua index 09688f6e..b44d6a2c 100644 --- a/script/core/diagnostics/init.lua +++ b/script/core/diagnostics/init.lua @@ -57,7 +57,7 @@ local function check(uri, name, results) end, name) local passed = os.clock() - clock if passed >= 0.5 then - log.warn(('Diagnostics [%s] @ [%s] takes [%.3f] sec!'):format(name, uri, passed)) + log.warn(('Diagnostics [%s] @ [%s] takes [%.3f] sec!'):format(name, files.getOriginUri(uri), passed)) end if DIAGTIMES then DIAGTIMES[name] = (DIAGTIMES[name] or 0) + passed diff --git a/script/core/generic.lua b/script/core/generic.lua index f5ede5d7..ce957a71 100644 --- a/script/core/generic.lua +++ b/script/core/generic.lua @@ -141,10 +141,18 @@ local function createValue(closure, proto, callback, road) return value end if proto.type == 'doc.type.field' then - road[#road+1] = ('%s%q'):format( - noder.SPLIT_CHAR, - proto.name[1] - ) + local name = proto.name[1] + if type(name) == 'string' then + road[#road+1] = ('%s%s'):format( + noder.STRING_FIELD, + name + ) + else + road[#road+1] = ('%s%s'):format( + noder.SPLIT_CHAR, + name + ) + end local typeUnit = createValue(closure, proto.extends, callback, road) road[#road] = nil if not typeUnit then diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua index 434eaff1..0d2bcf6f 100644 --- a/script/core/hover/label.lua +++ b/script/core/hover/label.lua @@ -9,6 +9,7 @@ local searcher = require 'core.searcher' local lang = require 'language' local config = require 'config' local files = require 'files' +local guide = require 'parser.guide' local function asFunction(source, oop) local name @@ -164,7 +165,7 @@ local function asNumber(source) if type(num) ~= 'number' then return nil end - local uri = searcher.getUri(source) + local uri = guide.getUri(source) local text = files.getText(uri) if not text then return nil diff --git a/script/core/infer.lua b/script/core/infer.lua index e1e0cd57..c2912e5d 100644 --- a/script/core/infer.lua +++ b/script/core/infer.lua @@ -534,11 +534,9 @@ function m.searchInfers(source, field, mark) searchInfer(source, infers, mark) local id = noder.getID(source) if id then - local node = noder.getNodeByID(source, id) - if node and node.source then - for src in noder.eachSource(node) do - searchInfer(src, infers, mark) - end + local noders = noder.getNoders(source) + for src in noder.eachSource(noders, id) do + searchInfer(src, infers, mark) end end if source.type == 'field' or source.type == 'method' then diff --git a/script/core/keyword.lua b/script/core/keyword.lua index 73892f18..b8e37605 100644 --- a/script/core/keyword.lua +++ b/script/core/keyword.lua @@ -277,8 +277,8 @@ until $1" or first == 'elseif' then local startRow = guide.positionOf(lines, info.start) local finishRow = guide.positionOf(lines, pos) - local startSp = info.text:match('^%s*', lines[startRow].start) - local finishSp = info.text:match('^%s*', lines[finishRow].start) + local startSp = info.text:match('^%s*', lines[startRow].start + 1) + local finishSp = info.text:match('^%s*', lines[finishRow].start + 1) if startSp == finishSp then return false end diff --git a/script/core/noder.lua b/script/core/noder.lua index 0ae48965..92eea166 100644 --- a/script/core/noder.lua +++ b/script/core/noder.lua @@ -3,47 +3,61 @@ local guide = require 'parser.guide' local collector = require 'core.collector' local files = require 'files' +local tostring = tostring +local error = error +local ipairs = ipairs +local type = type +local next = next +local log = log +local ssub = string.sub +local sformat = string.format +local sgsub = string.gsub +local smatch = string.match + +_ENV = nil + local SPLIT_CHAR = '\x1F' local LAST_REGEX = SPLIT_CHAR .. '[^' .. SPLIT_CHAR .. ']*$' local FIRST_REGEX = '^[^' .. SPLIT_CHAR .. ']*' -local HEAD_REGEX = '^' .. SPLIT_CHAR .. '?[^' .. SPLIT_CHAR .. ']*' +local HEAD_REGEX = '^' .. SPLIT_CHAR .. '?[^' .. SPLIT_CHAR .. ']*' +local STRING_CHAR = '.' local ANY_FIELD_CHAR = '*' local INDEX_CHAR = '[' local RETURN_INDEX = SPLIT_CHAR .. '#' local PARAM_INDEX = SPLIT_CHAR .. '&' local TABLE_KEY = SPLIT_CHAR .. '<' local WEAK_TABLE_KEY = SPLIT_CHAR .. '<<' +local STRING_FIELD = SPLIT_CHAR .. STRING_CHAR local INDEX_FIELD = SPLIT_CHAR .. INDEX_CHAR local ANY_FIELD = SPLIT_CHAR .. ANY_FIELD_CHAR local WEAK_ANY_FIELD = SPLIT_CHAR .. ANY_FIELD_CHAR .. ANY_FIELD_CHAR local URI_CHAR = '@' local URI_REGEX = URI_CHAR .. '([^' .. URI_CHAR .. ']*)' .. URI_CHAR .. '(.*)' ----@class node --- 当前节点的id ----@field id string +---@alias node.id string +---@alias node.filter fun(id: string, field?: string):boolean + +---@class noders -- 使用该ID的单元 ----@field source parser.guide.object +---@field source table<node.id, parser.guide.object> -- 使用该ID的单元 ----@field sources parser.guide.object[] --- 前进的关联ID ----@field forward string --- 第一个前进关联的tag ----@field finfo? node.info +---@field sources table<node.id, parser.guide.object[]> -- 前进的关联ID ----@field forwards string[] +---@field forward table<node.id, node.id> +-- 第一个前进关联的info +---@field finfo? table<node.id, node.info> +-- 前进的关联ID与info +---@field forwards table<node.id, node.id[]|table<node.id, node.info>> -- 后退的关联ID ----@field backward string --- 第一个后退关联的tag ----@field binfo? node.info --- 后退的关联ID ----@field backwards string[] +---@field backward table<node.id, node.id> +-- 第一个后退关联的info +---@field binfo? table<node.id, node.info> +-- 后退的关联ID与info +---@field backwards table<node.id, node.id[]|table<node.id, node.info>> -- 函数调用参数信息(用于泛型) ----@field call parser.guide.object ----@field skip boolean - ----@alias noders table<string, node[]> ----@alias node.filter fun(id: string, field?: string):boolean +---@field call table<node.id, parser.guide.object> +---@field require table<node.id, string> +---@field skip table<node.id, boolean> ---@class node.info ---@field reject? string @@ -52,19 +66,6 @@ local URI_REGEX = URI_CHAR .. '([^' .. URI_CHAR .. ']*)' .. URI_CHAR .. '(. ---@field filterValid? node.filter ---@field dontCross? boolean ----创建source的链接信息 ----@param noders noders ----@param id string ----@return node -local function getNode(noders, id) - if not noders[id] then - noders[id] = { - id = id, - } - end - return noders[id] -end - ---如果对象是 arg self, 则认为 id 是 method 的 node ---@param source parser.guide.object ---@return nil @@ -92,95 +93,137 @@ local function getMethodNode(source) end end ----获取语法树单元的key ----@param source parser.guide.object ----@return string? key ----@return parser.guide.object? node -local function getKey(source) - if source.type == 'local' then +local getKey +local getKeyMap = util.switch() + : case 'local' + : call(function (source) if source.parent.type == 'funcargs' then return 'p:' .. source.start, nil end return 'l:' .. source.start, nil - elseif source.type == 'setlocal' - or source.type == 'getlocal' then + end) + : case 'setlocal' + : case 'getlocal' + : call(function (source) return getKey(source.node) - elseif source.type == 'setglobal' - or source.type == 'getglobal' then + end) + : case 'setglobal' + : case 'getglobal' + : call(function (source) local node = source.node if node.tag == '_ENV' then - return ('%q'):format(source[1] or ''), nil + return STRING_CHAR .. (source[1] or ''), nil else - return ('%q'):format(source[1] or ''), node - end - elseif source.type == 'getfield' - or source.type == 'setfield' then - return ('%q'):format(source.field and source.field[1] or ''), source.node - elseif source.type == 'tablefield' then - return ('%q'):format(source.field and source.field[1] or ''), source.parent - elseif source.type == 'getmethod' - or source.type == 'setmethod' then - return ('%q'):format(source.method and source.method[1] or ''), source.node - elseif source.type == 'setindex' - or source.type == 'getindex' then + return STRING_CHAR .. (source[1] or ''), node + end + end) + : case 'getfield' + : case 'setfield' + : call(function (source) + return STRING_CHAR .. (source.field and source.field[1] or ''), source.node + end) + : case 'tablefield' + : call(function (source) + return STRING_CHAR .. (source.field and source.field[1] or ''), source.parent + end) + : case 'getmethod' + : case 'setmethod' + : call(function (source) + return STRING_CHAR .. (source.method and source.method[1] or ''), source.node + end) + : case 'setindex' + : case 'getindex' + : call(function (source) local index = source.index if not index then return INDEX_CHAR, source.node end - if index.type == 'string' - or index.type == 'boolean' - or index.type == 'integer' - or index.type == 'number' then - return ('%q'):format(index[1] or ''), source.node + if index.type == 'string' then + return STRING_CHAR .. (index[1] or ''), source.node + elseif index.type == 'boolean' + or index.type == 'integer' + or index.type == 'number' then + return tostring(index[1] or ''), source.node else return INDEX_CHAR, source.node end - elseif source.type == 'tableindex' then + end) + : case 'tableindex' + : call(function (source) local index = source.index if not index then return ANY_FIELD_CHAR, source.parent end - if index.type == 'string' - or index.type == 'boolean' - or index.type == 'integer' - or index.type == 'number' then - return ('%q'):format(index[1] or ''), source.parent + if index.type == 'string' then + return STRING_CHAR .. (index[1] or ''), source.parent + elseif index.type == 'boolean' + or index.type == 'integer' + or index.type == 'number' then + return tostring(index[1] or ''), source.parent elseif index.type ~= 'function' and index.type ~= 'table' then return ANY_FIELD_CHAR, source.parent end - elseif source.type == 'tableexp' then + end) + : case 'tableexp' + : call(function (source) return tostring(source.tindex), source.parent - elseif source.type == 'table' then + end) + : case 'table' + : call(function (source) return 't:' .. source.start, nil - elseif source.type == 'label' then + end) + : case 'label' + : call(function (source) return 'l:' .. source.start, nil - elseif source.type == 'goto' then + end) + : case 'goto' + : call(function (source) if source.node then return 'l:' .. source.node.start, nil end return nil, nil - elseif source.type == 'function' then + end) + : case 'function' + : call(function (source) return 'f:' .. source.start, nil - elseif source.type == 'string' then + end) + : case 'string' + : call(function (source) return 'str:', nil - elseif source.type == 'integer' then - return 'int:' - elseif source.type == 'number' then - return 'num:' - elseif source.type == 'boolean' then - return 'bool:' - elseif source.type == 'nil' then + end) + : case 'integer' + : call(function (source) + return 'int:', nil + end) + : case 'number' + : call(function (source) + return 'num:', nil + end) + : case 'boolean' + : call(function (source) + return 'bool:', nil + end) + : case 'nil' + : call(function (source) return 'nil:', nil - elseif source.type == '...' then + end) + : case '...' + : call(function (source) return 'va:' .. source.start, nil - elseif source.type == 'varargs' then + end) + : case 'varargs' + : call(function (source) if source.node then return 'va:' .. source.node.start, nil end - elseif source.type == 'select' then - return ('s:%d%s%d'):format(source.start, RETURN_INDEX, source.sindex) - elseif source.type == 'call' then + end) + : case 'select' + : call(function (source) + return sformat('s:%d%s%d', source.start, RETURN_INDEX, source.sindex) + end) + : case 'call' + : call(function (source) local node = source.node if node.special == 'rawget' or node.special == 'rawset' then @@ -192,66 +235,111 @@ local function getKey(source) return nil, nil end if key.type == 'string' then - return ('%q'):format(key[1] or ''), tbl + return STRING_CHAR .. (key[1] or ''), tbl else return '', tbl end end return 'c:' .. source.finish, nil - elseif source.type == 'doc.class.name' - or source.type == 'doc.alias.name' - or source.type == 'doc.extends.name' then + end) + : case 'doc.class.name' + : case 'doc.alias.name' + : case 'doc.extends.name' + : call(function (source) local name = source[1] return 'dn:' .. name, nil - elseif source.type == 'doc.type.name' then + end) + : case 'doc.type.name' + : call(function (source) local name = source[1] if source.typeGeneric then return 'dg:' .. source.typeGeneric[name][1].start, nil else return 'dn:' .. name, nil end - elseif source.type == 'doc.see.name' then + end) + : case 'doc.see.name' + : call(function (source) local name = source[1] return 'dsn:' .. name, nil - elseif source.type == 'doc.class' then + end) + : case 'doc.class' + : call(function (source) return 'dc:' .. source.start - elseif source.type == 'doc.type' then + end) + : case 'doc.type' + : call(function (source) return 'dt:' .. source.start - elseif source.type == 'doc.param' then + end) + : case 'doc.param' + : call(function (source) return 'dp:' .. source.start - elseif source.type == 'doc.vararg' then + end) + : case 'doc.vararg' + : call(function (source) return 'dv:' .. source.start - elseif source.type == 'doc.field.name' then + end) + : case 'doc.field.name' + : call(function (source) return 'dfn:' .. source.start - elseif source.type == 'doc.type.enum' - or source.type == 'doc.resume' then + end) + : case 'doc.type.enum' + : case 'doc.resume' + : call(function (source) return 'de:' .. source.start - elseif source.type == 'doc.type.table' then + end) + : case 'doc.type.table' + : call(function (source) return 'dtable:' .. source.start - elseif source.type == 'doc.type.ltable' then + end) + : case 'doc.type.ltable' + : call(function (source) return 'dltable:' .. source.start - elseif source.type == 'doc.type.field' then + end) + : case 'doc.type.field' + : call(function (source) return 'dfield:' .. source.start - elseif source.type == 'doc.type.array' then + end) + : case 'doc.type.array' + : call(function (source) return 'darray:' .. source.finish - elseif source.type == 'doc.type.function' then + end) + : case 'doc.type.function' + : call(function (source) return 'dfun:' .. source.start, nil - elseif source.type == 'doc.see.field' then - return ('%q'):format(source[1]), source.parent.name - elseif source.type == 'generic.closure' then + end) + : case 'doc.see.field' + : call(function (source) + return STRING_CHAR .. (source[1]), source.parent.name + end) + : case 'generic.closure' + : call(function (source) return 'gc:' .. source.call.start, nil - elseif source.type == 'generic.value' then + end) + : case 'generic.value' + : call(function (source) local tail = '' if guide.getUri(source.closure.call) ~= guide.getUri(source.proto) then tail = URI_CHAR .. guide.getUri(source.closure.call) end - return ('gv:%s|%s%s'):format( - source.closure.call.start, - getKey(source.proto), - tail + return sformat('gv:%s|%s%s' + , source.closure.call.start + , getKey(source.proto) + , tail ) + end) + : getMap() + +---获取语法树单元的key +---@param source parser.guide.object +---@return string? key +---@return parser.guide.object? node +function getKey(source) + local f = getKeyMap[source.type] + if f then + return f(source) end - return nil, nil + return nil end local function getNodeKey(source) @@ -270,7 +358,6 @@ local function getNodeKey(source) return key, node end -local IDList = {} ---获取语法树单元的字符串ID ---@param source parser.guide.object ---@return string? id @@ -287,36 +374,26 @@ local function getID(source) return nil end local current = source - local index = 0 - while true do - if current.type == 'paren' then - current = current.exp - if not current then - return nil - end - goto CONTINUE - end - local id, node = getNodeKey(current) - if not id then - break - end - index = index + 1 - IDList[index] = id - if not node then - break + while current.type == 'paren' do + current = current.exp + if not current then + source._id = false + return nil end - current = node - ::CONTINUE:: end - if index == 0 then + local id, node = getNodeKey(current) + if not id then source._id = false return nil end - for i = index + 1, #IDList do - IDList[i] = nil + if node then + local pid = getID(node) + if not pid then + source._id = false + return nil + end + id = pid .. SPLIT_CHAR .. id end - util.revertTable(IDList) - local id = table.concat(IDList, SPLIT_CHAR) source._id = id return id end @@ -333,23 +410,24 @@ local function pushForward(noders, id, forwardID, info) or id == forwardID then return end - local node = getNode(noders, id) - if not node.forward then - node.forward = forwardID - node.finfo = info + if not noders.forward[id] then + noders.forward[id] = forwardID + noders.finfo[id] = info return end - if node.forward == forwardID then + if noders.forward[id] == forwardID then return end - if not node.forwards then - node.forwards = {} + local forwards = noders.forwards[id] + if not forwards then + forwards = {} + noders.forwards[id] = forwards end - if node.forwards[forwardID] ~= nil then + if forwards[forwardID] ~= nil then return end - node.forwards[forwardID] = info or false - node.forwards[#node.forwards+1] = forwardID + forwards[forwardID] = info or false + forwards[#forwards+1] = forwardID end ---添加关联的后退ID @@ -364,29 +442,32 @@ local function pushBackward(noders, id, backwardID, info) or id == backwardID then return end - local node = getNode(noders, id) - if not node.backward then - node.backward = backwardID - node.binfo = info + if not noders.backward[id] then + noders.backward[id] = backwardID + noders.binfo[id] = info return end - if node.backward == backwardID then + if noders.backward[id] == backwardID then return end - if not node.backwards then - node.backwards = {} + local backwards = noders.backwards[id] + if not backwards then + backwards = {} + noders.backwards[id] = backwards end - if node.backwards[backwardID] ~= nil then + if backwards[backwardID] ~= nil then return end - node.backwards[backwardID] = info or false - node.backwards[#node.backwards+1] = backwardID + backwards[backwardID] = info or false + backwards[#backwards+1] = backwardID end ---@class noder local m = {} m.SPLIT_CHAR = SPLIT_CHAR +m.STRING_CHAR = STRING_CHAR +m.STRING_FIELD = STRING_FIELD m.RETURN_INDEX = RETURN_INDEX m.PARAM_INDEX = PARAM_INDEX m.TABLE_KEY = TABLE_KEY @@ -416,47 +497,50 @@ local function getDocStateWithoutCrossFunction(obj) error('guide.getDocState overstack') end +local dontPushSourceMap = util.arrayToHash { + 'str:', 'nil:', 'num:', 'int:', 'bool:' +} + ---添加关联单元 ---@param noders noders ---@param source parser.guide.object function m.pushSource(noders, source, id) - id = id or m.getID(source) + id = id or getID(source) if not id then return end - if id == 'str:' - or id == 'nil:' - or id == 'num:' - or id == 'int:' - or id == 'bool:' then + if dontPushSourceMap[id] then return end - local node = getNode(noders, id) - if not node.source then - node.source = source + if not noders.source[id] then + noders.source[id] = source return end - if not node.sources then - node.sources = {} + local sources = noders.sources[id] + if not sources then + sources = {} + noders.sources[id] = sources end - node.sources[#node.sources+1] = source + sources[#sources+1] = source end local DUMMY_FUNCTION = function () end ---遍历关联单元 ----@param node node +---@param noders noders +---@param id node.id ---@return fun():parser.guide.object -function m.eachSource(node) - if not node.source then +function m.eachSource(noders, id) + local source = noders.source[id] + if not source then return DUMMY_FUNCTION end local index - local sources = node.sources + local sources = noders.sources[id] return function () if not index then index = 0 - return node.source + return source end if not sources then return nil @@ -467,18 +551,20 @@ function m.eachSource(node) end ---遍历forward ----@param node node +---@param noders noders +---@param id node.id ---@return fun():string, string -function m.eachForward(node) - if not node.forward then +function m.eachForward(noders, id) + local forward = noders.forward[id] + if not forward then return DUMMY_FUNCTION end local index - local forwards = node.forwards + local forwards = noders.forwards[id] return function () if not index then index = 0 - return node.forward, node.finfo + return forward, noders.finfo[id] end if not forwards then return nil @@ -491,18 +577,20 @@ function m.eachForward(node) end ---遍历backward ----@param node node ----@return fun():string, node.info -function m.eachBackward(node) - if not node.backward then +---@param noders noders +---@param id node.id +---@return fun():string, string +function m.eachBackward(noders, id) + local backward = noders.backward[id] + if not backward then return DUMMY_FUNCTION end local index - local backwards = node.backwards + local backwards = noders.backwards[id] return function () if not index then index = 0 - return node.backward, node.binfo + return backward, noders.binfo[id] end if not backwards then return nil @@ -523,6 +611,7 @@ local function bindValue(noders, source, id) if not valueID then return end + m.compilePartNodes(noders, value) if source.type == 'getlocal' or source.type == 'setlocal' then source = source.node @@ -540,7 +629,7 @@ local function bindValue(noders, source, id) reject = 'set', }) -- 参数/call禁止反向查找赋值 - local valueType = valueID:match '^(.-:).' + local valueType = smatch(valueID, '^(.-:).') if not valueType then return end @@ -583,12 +672,12 @@ local function compileCallParam(noders, call, sourceID) if firstIndex > 0 and callArg.type == 'function' then if callArg.args then for secondIndex, funcParam in ipairs(callArg.args) do - local paramID = ('%s%s%s%s%s'):format( - nodeID, - PARAM_INDEX, - firstIndex, - PARAM_INDEX, - secondIndex + local paramID = sformat('%s%s%s%s%s' + , nodeID + , PARAM_INDEX + , firstIndex + , PARAM_INDEX + , secondIndex ) pushForward(noders, getID(funcParam), paramID) end @@ -616,10 +705,10 @@ local function compileCallReturn(noders, call, sourceID, returnIndex) local metaID = getID(call.args and call.args[2]) local indexID if metaID then - indexID = ('%s%s%q'):format( - metaID, - SPLIT_CHAR, - '__index' + indexID = sformat('%s%s%s' + , metaID + , STRING_FIELD + , '__index' ) end pushForward(noders, sourceID, tblID) @@ -628,7 +717,7 @@ local function compileCallReturn(noders, call, sourceID, returnIndex) if field then return true end - return id:sub(1, 2) ~= 'f:' + return ssub(id, 1, 2) ~= 'f:' end, filterValid = function (id, field) return not field @@ -641,7 +730,7 @@ local function compileCallReturn(noders, call, sourceID, returnIndex) if node.special == 'require' then local arg1 = call.args and call.args[1] if arg1 and arg1.type == 'string' then - getNode(noders, sourceID).require = arg1[1] + noders.require[sourceID] = arg1[1] end pushBackward(noders, callID, sourceID, { deep = true, @@ -658,10 +747,10 @@ local function compileCallReturn(noders, call, sourceID, returnIndex) if not funcID then return end - local pfuncXID = ('%s%s%s'):format( - funcID, - RETURN_INDEX, - index + local pfuncXID = sformat('%s%s%s' + , funcID + , RETURN_INDEX + , index ) pushForward(noders, sourceID, pfuncXID) pushBackward(noders, pfuncXID, sourceID, { @@ -669,179 +758,153 @@ local function compileCallReturn(noders, call, sourceID, returnIndex) }) return end - local funcXID = ('%s%s%s'):format( - nodeID, - RETURN_INDEX, - returnIndex + local funcXID = sformat('%s%s%s' + , nodeID + , RETURN_INDEX + , returnIndex ) - getNode(noders, sourceID).call = call + noders.call[sourceID] = call pushForward(noders, sourceID, funcXID) pushBackward(noders, funcXID, sourceID, { deep = true, }) end -function m.compileDocValue(noders, tp, id, source) - if tp == 'doc.type' then +local specialMap = util.arrayToHash { + 'require', 'dofile', 'loadfile', + 'rawset', 'rawget', 'setmetatable', +} + +local compileNodeMap +compileNodeMap = util.switch() + : case 'string' + : call(function (noders, id, source) + pushForward(noders, id, 'str:') + end) + : case 'boolean' + : call(function (noders, id, source) + pushForward(noders, id, 'dn:boolean') + end) + : case 'number' + : call(function (noders, id, source) + pushForward(noders, id, 'dn:number') + end) + : case 'integer' + : call(function (noders, id, source) + pushForward(noders, id, 'dn:integer') + end) + : case 'nil' + : call(function (noders, id, source) + pushForward(noders, id, 'dn:nil') + end) + -- self -> mt:xx + : case 'local' + : call(function (noders, id, source) + if source[1] ~= 'self' then + return + end + local func = guide.getParentFunction(source) + if func.isGeneric then + return + end + if source.parent.type ~= 'funcargs' then + return + end + local setmethod = func.parent + -- guess `self` + if setmethod and ( setmethod.type == 'setmethod' + or setmethod.type == 'setfield' + or setmethod.type == 'setindex') then + pushForward(noders, id, getID(setmethod.node)) + pushBackward(noders, getID(setmethod.node), id, { + deep = true, + }) + end + end) + : case 'doc.type' + : call(function (noders, id, source) if source.bindSources then for _, src in ipairs(source.bindSources) do pushForward(noders, getID(src), id) pushForward(noders, id, getID(src)) + m.compilePartNodes(noders, src) end end for _, enumUnit in ipairs(source.enums) do pushForward(noders, id, getID(enumUnit)) + m.compilePartNodes(noders, enumUnit) end for _, resumeUnit in ipairs(source.resumes) do pushForward(noders, id, getID(resumeUnit)) + m.compilePartNodes(noders, resumeUnit) end for _, typeUnit in ipairs(source.types) do local unitID = getID(typeUnit) pushForward(noders, id, unitID) + m.compilePartNodes(noders, typeUnit) if source.bindSources then for _, src in ipairs(source.bindSources) do pushBackward(noders, unitID, getID(src)) end end end - end - if tp == 'doc.type.table' then + end) + : case 'doc.type.table' + : call(function (noders, id, source) if source.tkey then - local keyID = ('%s%s'):format( - id, - TABLE_KEY - ) + local keyID = id .. TABLE_KEY pushForward(noders, keyID, getID(source.tkey)) end if source.tvalue then - local valueID = ('%s%s'):format( - id, - ANY_FIELD - ) + local valueID = id .. ANY_FIELD pushForward(noders, valueID, getID(source.tvalue)) end - end - if tp == 'doc.type.ltable' then + end) + : case 'doc.type.ltable' + : call(function (noders, id, source) local firstField = source.fields[1] if not firstField then return end - local keyID = ('%s%s'):format( - id, - WEAK_TABLE_KEY - ) - local valueID = ('%s%s'):format( - id, - WEAK_ANY_FIELD - ) + local keyID = id .. WEAK_TABLE_KEY + local valueID = id .. WEAK_ANY_FIELD pushForward(noders, keyID, 'dn:string') pushForward(noders, valueID, getID(firstField.extends)) for _, field in ipairs(source.fields) do - local extendsID = ('%s%s%q'):format( - id, - SPLIT_CHAR, - field.name[1] - ) + local fname = field.name[1] + local extendsID + if type(fname) == 'string' then + extendsID = sformat('%s%s%s' + , id + , STRING_FIELD + , fname + ) + else + extendsID = sformat('%s%s%s' + , id + , SPLIT_CHAR + , fname + ) + end pushForward(noders, extendsID, getID(field)) pushForward(noders, extendsID, getID(field.extends)) end - end - if tp == 'doc.type.array' then + end) + : case 'doc.type.array' + : call(function (noders, id, source) if source.node then - local nodeID = ('%s%s'):format( - id, - ANY_FIELD - ) + local nodeID = id .. ANY_FIELD pushForward(noders, nodeID, getID(source.node)) end - local keyID = ('%s%s'):format( - id, - TABLE_KEY - ) + local keyID = id .. TABLE_KEY pushForward(noders, keyID, 'dn:integer') - end -end - ----@param noders noders ----@param source parser.guide.object ----@return parser.guide.object[] -function m.compileNode(noders, source) - local id = getID(source) - bindValue(noders, source, id) - if source.special == 'setmetatable' - or source.special == 'require' - or source.special == 'dofile' - or source.special == 'loadfile' - or source.special == 'rawset' - or source.special == 'rawget' then - local node = getNode(noders, id) - node.skip = true - end - if source.type == 'string' then - pushForward(noders, id, 'str:') - end - if source.type == 'boolean' then - pushForward(noders, id, 'dn:boolean') - end - if source.type == 'number' then - pushForward(noders, id, 'dn:number') - end - if source.type == 'integer' then - pushForward(noders, id, 'dn:integer') - end - if source.type == 'nil' then - pushForward(noders, id, 'dn:nil') - end - -- self -> mt:xx - if source.type == 'local' and source[1] == 'self' then - local func = guide.getParentFunction(source) - if func.isGeneric then - return - end - if source.parent.type ~= 'funcargs' then - return - end - local setmethod = func.parent - -- guess `self` - if setmethod and ( setmethod.type == 'setmethod' - or setmethod.type == 'setfield' - or setmethod.type == 'setindex') then - pushForward(noders, id, getID(setmethod.node)) - pushBackward(noders, getID(setmethod.node), id, { - deep = true, - }) - end - end - -- 分解 @type - --if source.type == 'doc.type' then - -- if source.bindSources then - -- for _, src in ipairs(source.bindSources) do - -- pushForward(noders, getID(src), id) - -- pushForward(noders, id, getID(src)) - -- end - -- end - -- for _, enumUnit in ipairs(source.enums) do - -- pushForward(noders, id, getID(enumUnit)) - -- end - -- for _, resumeUnit in ipairs(source.resumes) do - -- pushForward(noders, id, getID(resumeUnit)) - -- end - -- for _, typeUnit in ipairs(source.types) do - -- local unitID = getID(typeUnit) - -- pushForward(noders, id, unitID) - -- if source.bindSources then - -- for _, src in ipairs(source.bindSources) do - -- pushBackward(noders, unitID, getID(src)) - -- end - -- end - -- end - --end - -- 分解 @alias - if source.type == 'doc.alias' then + end) + : case 'doc.alias' + : call(function (noders, id, source) pushForward(noders, getID(source.alias), getID(source.extends)) - end - -- 分解 @class - if source.type == 'doc.class' then + end) + : case 'doc.class' + : call(function (noders, id, source) pushForward(noders, id, getID(source.class)) pushForward(noders, getID(source.class), id) if source.extends then @@ -858,47 +921,60 @@ function m.compileNode(noders, source) for _, field in ipairs(source.fields) do local key = field.field[1] if key then - local keyID = ('%s%s%q'):format( - id, - SPLIT_CHAR, - key - ) + local keyID + if type(key) == 'string' then + keyID = sformat('%s%s%s' + , id + , STRING_FIELD + , key + ) + else + keyID = sformat('%s%s%s' + , id + , SPLIT_CHAR + , key + ) + end pushForward(noders, keyID, getID(field.field)) pushForward(noders, getID(field.field), keyID) pushForward(noders, keyID, getID(field.extends)) pushBackward(noders, getID(field.extends), keyID) end end - end - if source.type == 'doc.param' then + end) + : case 'doc.param' + : call(function (noders, id, source) pushForward(noders, id, getID(source.extends)) for _, src in ipairs(source.bindSources) do if src.type == 'local' and src.parent.type == 'in' then pushForward(noders, getID(src), id) end end - end - if source.type == 'doc.vararg' then + end) + : case 'doc.vararg' + : call(function (noders, id, source) pushForward(noders, getID(source), getID(source.vararg)) - end - if source.type == 'doc.see' then + end) + : case 'doc.see' + : call(function (noders, id, source) local nameID = getID(source.name) - local classID = nameID:gsub('^dsn:', 'dn:') + local classID = sgsub(nameID, '^dsn:', 'dn:') pushForward(noders, nameID, classID) if source.field then local fieldID = getID(source.field) - local fieldClassID = fieldID:gsub('^dsn:', 'dn:') + local fieldClassID = sgsub(fieldID, '^dsn:', 'dn:') pushForward(noders, fieldID, fieldClassID) end - end - m.compileDocValue(noders, source.type, id, source) - if source.type == 'call' then + end) + : case 'call' + : call(function (noders, id, source) if source.parent.type ~= 'select' then compileCallReturn(noders, source, id, 1) end compileCallParam(noders, source, id) - end - if source.type == 'select' then + end) + : case 'select' + : call(function (noders, id, source) if source.vararg.type == 'call' then local call = source.vararg compileCallReturn(noders, call, id, source.sindex) @@ -906,22 +982,23 @@ function m.compileNode(noders, source) if source.vararg.type == 'varargs' then pushForward(noders, id, getID(source.vararg)) end - end - if source.type == 'doc.type.function' then + end) + : case 'doc.type.function' + : call(function (noders, id, source) if source.returns then for index, rtn in ipairs(source.returns) do - local returnID = ('%s%s%s'):format( - id, - RETURN_INDEX, - index + local returnID = sformat('%s%s%s' + , id + , RETURN_INDEX + , index ) pushForward(noders, returnID, getID(rtn)) end for index, param in ipairs(source.args) do - local paramID = ('%s%s%s'):format( - id, - PARAM_INDEX, - index + local paramID = sformat('%s%s%s' + , id + , PARAM_INDEX + , index ) pushForward(noders, paramID, getID(param.extends)) end @@ -936,52 +1013,38 @@ function m.compileNode(noders, source) end end) end - end - if source.type == 'doc.type.name' then + end) + : case 'doc.type.name' + : call(function (noders, id, source) local uri = guide.getUri(source) - collector.subscribe(uri, id, getNode(noders, id)) - end - if source.type == 'doc.class.name' - or source.type == 'doc.alias.name' then + collector.subscribe(uri, id, noders) + end) + : case 'doc.class.name' + : case 'doc.alias.name' + : call(function (noders, id, source) local uri = guide.getUri(source) - collector.subscribe(uri, id, getNode(noders, id)) + collector.subscribe(uri, id, noders) local defID = 'def:' .. id - collector.subscribe(uri, defID, getNode(noders, defID)) + collector.subscribe(uri, defID, noders) m.pushSource(noders, source, defID) local defAnyID = 'def:dn:' - collector.subscribe(uri, defAnyID, getNode(noders, defAnyID)) + collector.subscribe(uri, defAnyID, noders) m.pushSource(noders, source, defAnyID) - end - if id and id:sub(1, 2) == 'g:' then - local uri = guide.getUri(source) - collector.subscribe(uri, id, getNode(noders, id)) - if guide.isSet(source) then - - local defID = 'def:' .. id - collector.subscribe(uri, defID, getNode(noders, defID)) - m.pushSource(noders, source, defID) - - if guide.isGlobal(source) then - local defAnyID = 'def:g:' - collector.subscribe(uri, defAnyID, getNode(noders, defAnyID)) - m.pushSource(noders, source, defAnyID) - end - end - end - -- 将函数的返回值映射到具体的返回值上 - if source.type == 'function' then + end) + : case 'function' + : call(function (noders, id, source) local hasDocReturn = {} -- 检查 luadoc if source.bindDocs then for _, doc in ipairs(source.bindDocs) do if doc.type == 'doc.return' then for _, rtn in ipairs(doc.returns) do - local fullID = ('%s%s%s'):format( - id, - RETURN_INDEX, - rtn.returnIndex + local fullID = sformat('%s%s%s' + , id + , RETURN_INDEX + , rtn.returnIndex ) pushForward(noders, fullID, getID(rtn)) for _, typeUnit in ipairs(rtn.types) do @@ -1001,10 +1064,10 @@ function m.compileNode(noders, source) if param then pushForward(noders, getID(param), getID(doc)) param.docParam = doc - local paramID = ('%s%s%s'):format( - id, - PARAM_INDEX, - paramIndex + local paramID = sformat('%s%s%s' + , id + , PARAM_INDEX + , paramIndex ) pushForward(noders, paramID, getID(doc.extends)) end @@ -1041,10 +1104,10 @@ function m.compileNode(noders, source) end end for index, rtnObjs in ipairs(returns) do - local returnID = ('%s%s%s'):format( - id, - RETURN_INDEX, - index + local returnID = sformat('%s%s%s' + , id + , RETURN_INDEX + , index ) for _, rtnObj in ipairs(rtnObjs) do pushForward(noders, returnID, getID(rtnObj)) @@ -1055,31 +1118,20 @@ function m.compileNode(noders, source) end end end - end - if source.type == 'table' then + end) + : case 'table' + : call(function (noders, id, source) local firstField = source[1] if firstField then if firstField.type == 'varargs' then - local keyID = ('%s%s'):format( - id, - TABLE_KEY - ) - local valueID = ('%s%s'):format( - id, - ANY_FIELD - ) + local keyID = id .. TABLE_KEY + local valueID = id .. ANY_FIELD source.array = firstField pushForward(noders, keyID, 'dn:integer') pushForward(noders, valueID, getID(firstField)) else - local keyID = ('%s%s'):format( - id, - WEAK_TABLE_KEY - ) - local valueID = ('%s%s'):format( - id, - WEAK_ANY_FIELD - ) + local keyID = id .. WEAK_TABLE_KEY + local valueID = id .. WEAK_ANY_FIELD if firstField.type == 'tablefield' then pushForward(noders, keyID, 'dn:string') pushForward(noders, valueID, getID(firstField.value)) @@ -1092,8 +1144,9 @@ function m.compileNode(noders, source) end end end - end - if source.type == 'main' then + end) + : case 'main' + : call(function (noders, id, source) if source.returns then for _, rtn in ipairs(source.returns) do local rtnObj = rtn[1] @@ -1105,19 +1158,21 @@ function m.compileNode(noders, source) end end end - end - if source.type == 'generic.closure' then + end) + : case 'generic.closure' + : call(function (noders, id, source) for i, rtn in ipairs(source.returns) do - local closureID = ('%s%s%s'):format( - id, - RETURN_INDEX, - i + local closureID = sformat('%s%s%s' + , id + , RETURN_INDEX + , i ) local returnID = getID(rtn) pushForward(noders, closureID, returnID) end - end - if source.type == 'generic.value' then + end) + : case 'generic.value' + : call(function (noders, id, source) local proto = source.proto local closure = source.closure local upvalues = closure.upvalues @@ -1130,34 +1185,57 @@ function m.compileNode(noders, source) end end end - --if proto.type == 'doc.type' then - -- for _, tp in ipairs(source.types) do - -- pushForward(noders, id, getID(tp)) - -- pushBackward(noders, getID(tp), id) - -- end - --end - m.compileDocValue(noders, proto.type, id, source) + local f = compileNodeMap[proto.type] + if f then + f(noders, id, source) + end + end) + : getMap() + +---@param noders noders +---@param source parser.guide.object +---@return parser.guide.object[] +function m.compileNode(noders, source) + if source._noded then + return end -end + source._noded = true + m.pushSource(noders, source) + local id = getID(source) + bindValue(noders, source, id) ----根据ID来获取所有的node ----@param root parser.guide.object ----@param id string ----@return node? -function m.getNodeByID(root, id) - root = guide.getRoot(root) - local noders = root._noders - if not noders then - return nil + if specialMap[source.special] then + noders.skip[id] = true + end + + local f = compileNodeMap[source.type] + if f then + f(noders, id, source) + end + + if id and ssub(id, 1, 2) == 'g:' then + local uri = guide.getUri(source) + collector.subscribe(uri, id, noders) + if guide.isSet(source) then + + local defID = 'def:' .. id + collector.subscribe(uri, defID, noders) + m.pushSource(noders, source, defID) + + if guide.isGlobal(source) then + local defAnyID = 'def:g:' + collector.subscribe(uri, defAnyID, noders) + m.pushSource(noders, source, defAnyID) + end + end end - return noders[id] end ---根据ID来获取第一个节点的ID ---@param id string ---@return string function m.getFirstID(id) - local firstID, count = id:match(FIRST_REGEX) + local firstID, count = smatch(id, FIRST_REGEX) if count == 0 then return nil end @@ -1171,7 +1249,7 @@ end ---@param id string ---@return string function m.getHeadID(id) - local headID, count = id:match(HEAD_REGEX) + local headID, count = smatch(id, HEAD_REGEX) if count == 0 then return nil end @@ -1185,7 +1263,7 @@ end ---@param id string ---@return string function m.getLastID(id) - local lastID, count = id:gsub(LAST_REGEX, '') + local lastID, count = sgsub(id, LAST_REGEX, '') if count == 0 then return nil end @@ -1202,7 +1280,7 @@ function m.getIDLength(id) if not id then return 0 end - local _, count = id:gsub(SPLIT_CHAR, SPLIT_CHAR) + local _, count = sgsub(id, SPLIT_CHAR, SPLIT_CHAR) return count + 1 end @@ -1214,11 +1292,11 @@ function m.hasField(id) if firstID == id or not firstID then return false end - local nextChar = id:sub(#firstID + 1, #firstID + 1) + local nextChar = ssub(id, #firstID + 1, #firstID + 1) if nextChar ~= SPLIT_CHAR then return false end - local next2Char = id:sub(#firstID + 2, #firstID + 2) + local next2Char = ssub(id, #firstID + 2, #firstID + 2) if next2Char == RETURN_INDEX or next2Char == PARAM_INDEX then return false @@ -1231,7 +1309,7 @@ end ---@return uri? string ---@return string id function m.getUriAndID(id) - local uri, newID = id:match(URI_REGEX) + local uri, newID = smatch(id, URI_REGEX) return uri, newID end @@ -1241,15 +1319,20 @@ function m.isCommonField(field) if not field then return false end - if field:sub(1, #RETURN_INDEX) == RETURN_INDEX then + if ssub(field, 1, #RETURN_INDEX) == RETURN_INDEX then return false end - if field:sub(1, #PARAM_INDEX) == PARAM_INDEX then + if ssub(field, 1, #PARAM_INDEX) == PARAM_INDEX then return false end return true end +function m.isGlobalID(id) + return ssub(id, 1, 2) == 'g:' + or ssub(id, 1, 3) == 'dn:' +end + ---获取source的ID ---@param source parser.guide.object ---@return string @@ -1265,15 +1348,15 @@ function m.getKey(source) end ---清除临时id(用于泛型的临时对象) ----@param root parser.guide.object +---@param noders noders ---@param id string -function m.removeID(root, id) +function m.removeID(noders, id) if not id then return end - root = guide.getRoot(root) - local noders = root._noders - noders[id] = nil + for _, t in next, noders do + t[id] = nil + end end ---寻找doc的主体 @@ -1288,36 +1371,178 @@ end function m.getNoders(source) local root = guide.getRoot(source) if not root._noders then - root._noders = {} + ---@type noders + root._noders = { + source = {}, + sources = {}, + forward = {}, + finfo = {}, + forwards = {}, + backward = {}, + binfo = {}, + backwards = {}, + call = {}, + require = {}, + skip = {}, + } end return root._noders end +---获取对象的noders +---@param uri uri +---@return noders +function m.getNodersByUri(uri) + local state = files.getState(uri) + if not state then + return nil + end + return m.getNoders(state.ast) +end + ---编译整个文件的node ---@param source parser.guide.object ---@return table -function m.compileNodes(source) +function m.compileAllNodes(source) local root = guide.getRoot(source) local noders = m.getNoders(source) - if next(noders) then + if root._initedNoders then return noders end + root._initedNoders = true log.debug('compileNodes:', guide.getUri(root)) collector.dropUri(guide.getUri(root)) guide.eachSource(root, function (src) - m.pushSource(noders, src) m.compileNode(noders, src) end) - log.debug('compileNodes finish:', guide.getUri(root)) + log.debug('compileNodes finish:', files.getOriginUri(guide.getUri(root))) return noders end +local partNodersMap = util.switch() + : case 'local' + : call(function (noders, source) + local refs = source.ref + if refs then + for i = 1, #refs do + local ref = refs[i] + m.compilePartNodes(noders, ref) + end + end + + local nxt = source.next + if nxt then + m.compilePartNodes(noders, nxt) + end + + local node = getMethodNode(source) + if node then + m.compilePartNodes(noders, node) + end + end) + : case 'setlocal' + : case 'getlocal' + : call(function (noders, source) + m.compilePartNodes(noders, source.node) + + local nxt = source.next + if nxt then + m.compilePartNodes(noders, nxt) + end + end) + : case 'setfield' + : case 'getfield' + : case 'setmethod' + : case 'getmethod' + : call(function (noders, source) + local node = source.node + m.compilePartNodes(noders, node) + + local nxt = source.next + if nxt then + m.compilePartNodes(noders, nxt) + end + end) + : case 'setglobal' + : case 'getglobal' + : call(function (noders, source) + local nxt = source.next + if nxt then + m.compilePartNodes(noders, nxt) + end + end) + : case 'label' + : call(function (noders, source) + local refs = source.ref + if not refs then + return + end + for i = 1, #refs do + local ref = refs[i] + m.compilePartNodes(noders, ref) + end + end) + : case 'goto' + : call(function (noders, source) + m.compilePartNodes(noders, source.node) + end) + : case 'table' + : call(function (noders, source) + for i = 1, #source do + local field = source[i] + m.compilePartNodes(noders, field) + end + end) + : case 'tablefield' + : case 'tableindex' + : call(function (noders, source) + m.compilePartNodes(noders, source.parent) + end) + : getMap() + +---编译Class的node +---@param noders noders +---@param source parser.guide.object +---@return table +function m.compilePartNodes(noders, source) + do return end + if source._noded then + return + end + m.compileNode(noders, source) + local f = partNodersMap[source.type] + if f then + f(noders, source) + end + + local parent = source.parent + if parent.value == source then + m.compilePartNodes(noders, parent) + end +end + +---编译全局变量的node +---@param root parser.guide.object +---@return table +function m.compileGlobalNodes(root) + local noders = m.getNoders(root) + local env = guide.getENV(root) + m.compilePartNodes(noders, env) + + local docs = root.docs + for i = 1, #docs do + local doc = docs[i] + m.compileNode(noders, doc) + end +end + files.watch(function (ev, uri) uri = files.asKey(uri) if ev == 'update' then local state = files.getState(uri) if state then - m.compileNodes(state.ast) + m.compileAllNodes(state.ast) + --m.compileGlobalNodes(state.ast) end end if ev == 'remove' then diff --git a/script/core/reference.lua b/script/core/reference.lua index 109bf601..8f113a8d 100644 --- a/script/core/reference.lua +++ b/script/core/reference.lua @@ -7,8 +7,8 @@ local findSource = require 'core.find-source' local function sortResults(results) -- 先按照顺序排序 table.sort(results, function (a, b) - local u1 = searcher.getUri(a.target) - local u2 = searcher.getUri(b.target) + local u1 = guide.getUri(a.target) + local u2 = guide.getUri(b.target) if u1 == u2 then return a.target.start < b.target.start else @@ -20,7 +20,7 @@ local function sortResults(results) for i = #results, 1, -1 do local res = results[i].target local f = res.finish - local uri = searcher.getUri(res) + local uri = guide.getUri(res) if lf and f > lf and uri == lu then table.remove(results, i) else diff --git a/script/core/searcher.lua b/script/core/searcher.lua index 6554779f..10571c03 100644 --- a/script/core/searcher.lua +++ b/script/core/searcher.lua @@ -5,6 +5,53 @@ local generic = require 'core.generic' local ws = require 'workspace' local vm = require 'vm.vm' local collector = require 'core.collector' +local util = require 'utility' + +local TRACE = TRACE +local FOOTPRINT = FOOTPRINT +local TEST = TEST +local log = log +local select = select +local tostring = tostring +local next = next +local error = error +local type = type +local setmetatable = setmetatable +local tconcat = table.concat +local ssub = string.sub +local sfind = string.find +local sformat = string.format + +local getUri = guide.getUri +local getRoot = guide.getRoot + +local ceach = collector.each + +local getNoders = noder.getNoders +local getID = noder.getID +local getLastID = noder.getLastID +local removeID = noder.removeID +local getNodersByUri = noder.getNodersByUri +local getFirstID = noder.getFirstID +local getHeadID = noder.getHeadID +local eachForward = noder.eachForward +local getUriAndID = noder.getUriAndID +local eachBackward = noder.eachBackward +local eachSource = noder.eachSource +local compileAllNodes = noder.compileAllNodes +local compilePartNoders = noder.compilePartNodes +local isGlobalID = noder.isGlobalID + +local SPLIT_CHAR = noder.SPLIT_CHAR +local RETURN_INDEX = noder.RETURN_INDEX +local TABLE_KEY = noder.TABLE_KEY +local STRING_CHAR = noder.STRING_CHAR +local STRING_FIELD = noder.STRING_FIELD +local WEAK_TABLE_KEY = noder.WEAK_TABLE_KEY +local ANY_FIELD = noder.ANY_FIELD +local WEAK_ANY_FIELD = noder.WEAK_ANY_FIELD + +_ENV = nil local ignoredSources = { ['int:'] = true, @@ -28,16 +75,99 @@ local ignoredIDs = { ['dn:thread'] = true, } +---@class searcher local m = {} ---@alias guide.searchmode '"ref"'|'"def"'|'"field"'|'"allref"'|'"alldef"' +local pushDefResultsMap = util.switch() + : case 'local' + : case 'setlocal' + : case 'setglobal' + : case 'label' + : case 'setfield' + : case 'setmethod' + : case 'setindex' + : case 'tableindex' + : case 'tablefield' + : case 'tableexp' + : case 'function' + : case 'table' + : case 'doc.class.name' + : case 'doc.alias.name' + : case 'doc.field.name' + : case 'doc.type.enum' + : case 'doc.resume' + : case 'doc.type.array' + : case 'doc.type.table' + : case 'doc.type.ltable' + : case 'doc.type.field' + : case 'doc.type.function' + : call(function (source, status) + return true + end) + : case 'call' + : call(function (source, status) + if source.node.special == 'rawset' then + return true + end + end) + : getMap() + +local pushRefResultsMap = util.switch() + : case 'local' + : case 'setlocal' + : case 'getlocal' + : case 'setglobal' + : case 'getglobal' + : case 'label' + : case 'goto' + : case 'setfield' + : case 'getfield' + : case 'setmethod' + : case 'getmethod' + : case 'setindex' + : case 'getindex' + : case 'tableindex' + : case 'tablefield' + : case 'tableexp' + : case 'function' + : case 'table' + : case 'string' + : case 'boolean' + : case 'number' + : case 'integer' + : case 'nil' + : case 'doc.class.name' + : case 'doc.type.name' + : case 'doc.alias.name' + : case 'doc.extends.name' + : case 'doc.field.name' + : case 'doc.type.enum' + : case 'doc.resume' + : case 'doc.type.array' + : case 'doc.type.table' + : case 'doc.type.ltable' + : case 'doc.type.field' + : case 'doc.type.function' + : call(function (source, status) + return true + end) + : case 'call' + : call(function (source, status) + if source.node.special == 'rawset' + or source.node.special == 'rawget' then + return true + end + end) + : getMap() + ---添加结果 ---@param status guide.status ---@param mode guide.searchmode ---@param source parser.guide.object ---@param force boolean -function m.pushResult(status, mode, source, force) +local function pushResult(status, mode, source, force) if not source then return end @@ -51,108 +181,28 @@ function m.pushResult(status, mode, source, force) results[#results+1] = source return end - local parent = source.parent + if mode == 'def' or mode == 'alldef' then - if source.type == 'local' - or source.type == 'setlocal' - or source.type == 'setglobal' - or source.type == 'label' - or source.type == 'setfield' - or source.type == 'setmethod' - or source.type == 'setindex' - or source.type == 'tableindex' - or source.type == 'tablefield' - or source.type == 'tableexp' - or source.type == 'function' - or source.type == 'table' - or source.type == 'doc.class.name' - or source.type == 'doc.alias.name' - or source.type == 'doc.field.name' - or source.type == 'doc.type.enum' - or source.type == 'doc.resume' - or source.type == 'doc.type.array' - or source.type == 'doc.type.table' - or source.type == 'doc.type.ltable' - or source.type == 'doc.type.field' - or source.type == 'doc.type.function' then + local f = pushDefResultsMap[source.type] + if f and f(source, status) then results[#results+1] = source return end - if source.type == 'call' then - if source.node.special == 'rawset' then - results[#results+1] = source - end - end - if parent.type == 'return' then - if noder.getID(source) ~= status.id then - results[#results+1] = source - end - end elseif mode == 'ref' or mode == 'field' or mode == 'allref' then - if source.type == 'local' - or source.type == 'setlocal' - or source.type == 'getlocal' - or source.type == 'setglobal' - or source.type == 'getglobal' - or source.type == 'label' - or source.type == 'goto' - or source.type == 'setfield' - or source.type == 'getfield' - or source.type == 'setmethod' - or source.type == 'getmethod' - or source.type == 'setindex' - or source.type == 'getindex' - or source.type == 'tableindex' - or source.type == 'tablefield' - or source.type == 'tableexp' - or source.type == 'function' - or source.type == 'table' - or source.type == 'string' - or source.type == 'boolean' - or source.type == 'number' - or source.type == 'integer' - or source.type == 'nil' - or source.type == 'doc.class.name' - or source.type == 'doc.type.name' - or source.type == 'doc.alias.name' - or source.type == 'doc.extends.name' - or source.type == 'doc.field.name' - or source.type == 'doc.type.enum' - or source.type == 'doc.resume' - or source.type == 'doc.type.array' - or source.type == 'doc.type.table' - or source.type == 'doc.type.ltable' - or source.type == 'doc.type.field' - or source.type == 'doc.type.function' then + local f = pushRefResultsMap[source.type] + if f and f(source, status) then results[#results+1] = source return end - if source.type == 'call' then - if source.node.special == 'rawset' - or source.node.special == 'rawget' then - results[#results+1] = source - end - end - if parent.type == 'return' then - if noder.getID(source) ~= status.id then - results[#results+1] = source - end - end end -end ----获取uri ----@param obj parser.guide.object ----@return uri -function m.getUri(obj) - if obj.uri then - return obj.uri - end - local root = guide.getRoot(obj) - if root then - return root.uri + local parent = source.parent + if parent.type == 'return' then + if source ~= status.source then + results[#results+1] = source + return + end end - return '' end ---@param obj parser.guide.object @@ -190,26 +240,6 @@ function m.getObjectValue(obj) return nil end -local function checkLock(status, k1, k2) - local locks = status.lock - local lock1 = locks[k1] - if not lock1 then - lock1 = {} - locks[k1] = lock1 - end - if lock1[''] then - return true - end - if k2 == nil then - k2 = '' - end - if lock1[k2] then - return true - end - lock1[k2] = true - return false -end - local strs = {} local function footprint(status, ...) if TRACE then @@ -220,127 +250,154 @@ local function footprint(status, ...) for i = 1, n do strs[i] = tostring(select(i, ...)) end - status.footprint[#status.footprint+1] = table.concat(strs, '\t', 1, n) - end -end - -local function crossSearch(status, uri, expect, mode, sourceUri) - if status.dontCross > 0 then - return - end - if checkLock(status, uri, expect) then - return + status.footprint[#status.footprint+1] = tconcat(strs, '\t', 1, n) end - footprint(status, 'crossSearch', uri, expect) - m.searchRefsByID(status, uri, expect, mode) - --status.lock[uri] = nil - footprint(status, 'crossSearch finish, back to:', sourceUri) end local function checkCache(status, uri, expect, mode) - local cache = vm.getCache('search:' .. mode) + local cache = status.cache local fileCache = cache[uri] - if not fileCache then - fileCache = {} - cache[uri] = fileCache - end - if fileCache[expect] then - for _, res in ipairs(fileCache[expect]) do - m.pushResult(status, mode, res, true) + local results = fileCache[expect] + if results then + for i = 1, #results do + local res = results[i] + pushResult(status, mode, res, true) end return true end - fileCache[expect] = status.results - return false + return false, function () + fileCache[expect] = status.results + if mode == 'def' + or mode == 'alldef' then + return + end + for id in next, status.ids do + fileCache[id] = status.results + end + end end local function stop(status, msg) if TEST then if FOOTPRINT then - log.debug(table.concat(status.footprint, '\n')) + log.debug(status.mode) + log.debug(tconcat(status.footprint, '\n')) end error(msg) else log.warn(msg) if FOOTPRINT then - log.debug(table.concat(status.footprint, '\n')) + log.debug(status.mode) + log.debug(tconcat(status.footprint, '\n')) end return end end -local function checkSLock(status, slock, id, field) - if noder.getIDLength(id) > 20 then - stop(status, 'too long!') - return false - end - local cmark = slock[id] - if not cmark then - cmark = {} - slock[id] = {} - end - if cmark[field or ''] then - return false - end - cmark[field or ''] = true - local right = '' - while field and field ~= '' do - local lastID = noder.getLastID(field) - if not lastID then - break - end - right = lastID .. right - if cmark[right] then - return false - end - field = field:sub(1, - #lastID - 1) - end - return true -end - local function isCallID(field) if not field then return false end - if field:sub(1, 2) == noder.RETURN_INDEX then + if ssub(field, 1, 2) == RETURN_INDEX then return true end return false end -function m.searchRefsByID(status, uri, expect, mode) - local ast = files.getState(uri) +local genercCache = { + mark = {}, + genericCallArgs = {}, + closureCache = {}, +} + +local function flushGeneric() + --清除来自泛型的临时对象 + for _, closure in next, genercCache.closureCache do + local noders = getNoders(closure) + removeID(noders, getID(closure)) + if closure then + local values = closure.values + for i = 1, #values do + local value = values[i] + removeID(noders, getID(value)) + end + end + end + genercCache.mark = {} + genercCache.closureCache = {} + genercCache.genericCallArgs = {} +end + +files.watch(function (ev) + if ev == 'version' then + flushGeneric() + end +end) + +local nodersMapMT = {__index = function (self, uri) + local noders = getNodersByUri(uri) + self[uri] = noders or false + return noders +end} + +local uriMapMT = {__index = function (self, uri) + local t = {} + self[uri] = t + return t +end} + +function m.searchRefsByID(status, suri, expect, mode) + local ast = files.getState(suri) if not ast then return end - local root = ast.ast local searchStep - noder.compileNodes(root) status.id = expect local callStack = status.callStack + local ids = status.ids + local dontCross = 0 + ---@type table<uri, noders> + local nodersMap = setmetatable({}, nodersMapMT) + local frejectMap = setmetatable({}, uriMapMT) + local brejectMap = setmetatable({}, uriMapMT) + local slockMap = setmetatable({}, uriMapMT) + local elockMap = setmetatable({}, uriMapMT) + + local function lockExpanding(elock, id, field) + if not field then + field = '' + end + local locked = elock[id] + if locked and field then + if #locked <= #field then + if ssub(field, -#locked) == locked then + footprint(status, 'elocked:', id, locked, field) + return false + end + end + end + elock[id] = field + return true + end - local slock = status.slock[uri] or {} - local elock = status.elock[uri] or {} - status.slock[uri] = slock - status.elock[uri] = elock + local function releaseExpanding(elock, id, field) + elock[id] = nil + end - local function search(id, field) - local firstID = noder.getFirstID(id) + local function search(uri, id, field) + local firstID = getFirstID(id) if ignoredIDs[firstID] and (field or firstID ~= id) then return end - if not checkSLock(status, slock, id, field) then - footprint(status, 'slocked:', id, field) - return - end + footprint(status, 'search:', id, field) - searchStep(id, field) + searchStep(uri, id, field) footprint(status, 'pop:', id, field) end - local function splitID(id, field) + local function splitID(uri, id, field) if field then return end @@ -348,7 +405,7 @@ function m.searchRefsByID(status, uri, expect, mode) local rightID while true do - local firstID = noder.getHeadID(rightID or id) + local firstID = getHeadID(rightID or id) if not firstID or firstID == id then return end @@ -356,8 +413,8 @@ function m.searchRefsByID(status, uri, expect, mode) if leftID == id then return end - rightID = id:sub(#leftID + 1) - search(leftID, rightID) + rightID = ssub(id, #leftID + 1) + search(uri, leftID, rightID) local isCall = isCallID(firstID) if isCall then break @@ -365,14 +422,23 @@ function m.searchRefsByID(status, uri, expect, mode) end end - local function searchID(id, field) + local function searchID(uri, id, field, sourceUri) if not id then return end + if not nodersMap[uri] then + return + end if field then id = id .. field end - search(id, nil) + ids[id] = true + if slockMap[uri][id] then + footprint(status, 'slocked:', id) + return + end + slockMap[uri][id] = true + search(uri, id, nil) end ---@return parser.guide.object? @@ -388,9 +454,9 @@ function m.searchRefsByID(status, uri, expect, mode) return nil end - local genericCallArgs = {} - local closureCache = {} - local function checkGeneric(source, field) + local genericCallArgs = genercCache.genericCallArgs + local closureCache = genercCache.closureCache + local function checkGeneric(uri, source, field) if not source.isGeneric then return end @@ -401,46 +467,56 @@ function m.searchRefsByID(status, uri, expect, mode) if not call then return end + local id = genercCache.mark[call] + if id == false then + return + end + if id then + searchID(uri, id, field) + return + end - if call.args then - for _, arg in ipairs(call.args) do + local args = call.args + if args then + for i = 1, #args do + local arg = args[i] genericCallArgs[arg] = true end end - local cacheID = noder.getID(source) .. noder.getID(call) + local cacheID = uri .. getID(source) .. getID(call) local closure = closureCache[cacheID] if closure == false then + genercCache.mark[call] = false return end if not closure then closure = generic.createClosure(source, call) closureCache[cacheID] = closure or false if not closure then + genercCache.mark[call] = false return end end - local id = noder.getID(closure) - searchID(id, field) + id = getID(closure) + genercCache.mark[call] = id + searchID(uri, id, field) end - local function checkENV(source, field) + local function checkENV(uri, source, field) if not field then return end if source.special ~= '_G' then return end - local newID = 'g:' .. field:sub(2) - searchID(newID) + local newID = 'g:' .. ssub(field, 2) + searchID(uri, newID) end - local freject = {} - local breject = {} - ---@param ward '"forward"'|'"backward"' ---@param info node.info - local function checkThenPushReject(ward, info) + local function checkThenPushReject(uri, ward, info) local reject = info.reject if not reject then return true @@ -448,11 +524,11 @@ function m.searchRefsByID(status, uri, expect, mode) local checkReject local pushReject if ward == 'forward' then - checkReject = breject - pushReject = freject + checkReject = brejectMap[uri] + pushReject = frejectMap[uri] else - checkReject = freject - pushReject = breject + checkReject = frejectMap[uri] + pushReject = brejectMap[uri] end if checkReject[reject] and checkReject[reject] > 0 then return false @@ -463,16 +539,16 @@ function m.searchRefsByID(status, uri, expect, mode) ---@param ward '"forward"'|'"backward"' ---@param info node.info - local function popReject(ward, info) + local function popReject(uri, ward, info) local reject = info.reject if not reject then return end local popTags if ward == 'forward' then - popTags = freject + popTags = frejectMap[uri] else - popTags = breject + popTags = brejectMap[uri] end popTags[reject] = popTags[reject] - 1 end @@ -515,7 +591,7 @@ function m.searchRefsByID(status, uri, expect, mode) ---@param id string ---@param info node.info local function checkInfoFilter(id, field, info) - for filter in pairs(filters) do + for filter in next, filters do if not filter(id, field) then return false end @@ -525,9 +601,9 @@ function m.searchRefsByID(status, uri, expect, mode) ---@param id string ---@param info node.info - local function checkInfoBeforeForward(id, field, info) + local function checkInfoBeforeForward(uri, id, field, info) pushInfoFilter(id, field, info) - if not checkThenPushReject('forward', info) then + if not checkThenPushReject(uri, 'forward', info) then return false end return true @@ -535,27 +611,29 @@ function m.searchRefsByID(status, uri, expect, mode) ---@param id string ---@param info node.info - local function releaseInfoAfterForward(id, field, info) - popReject('forward', info) + local function releaseInfoAfterForward(uri, id, field, info) + popReject(uri, 'forward', info) releaseInfoFilter(id, field, info) end - local function checkForward(id, node, field) - for forwardID, info in noder.eachForward(node) do - if info and not checkInfoBeforeForward(forwardID, field, info) then + local function checkForward(uri, id, field) + for forwardID, info in eachForward(nodersMap[uri], id) do + if info and not checkInfoBeforeForward(uri, forwardID, field, info) then goto CONTINUE end if not checkInfoFilter(forwardID, field, info) then goto CONTINUE end - local targetUri, targetID = noder.getUriAndID(forwardID) - if targetUri and not files.eq(targetUri, uri) then - crossSearch(status, targetUri, targetID .. (field or ''), mode, uri) + local targetUri, targetID = getUriAndID(forwardID) + if targetUri and targetUri ~= uri then + if dontCross == 0 then + searchID(targetUri, targetID, field, uri) + end else - searchID(targetID or forwardID, field) + searchID(uri, targetID or forwardID, field) end if info then - releaseInfoAfterForward(forwardID, field, info) + releaseInfoAfterForward(uri, forwardID, field, info) end ::CONTINUE:: end @@ -564,16 +642,16 @@ function m.searchRefsByID(status, uri, expect, mode) ---@param id string ---@param field string ---@param info node.info - local function checkInfoBeforeBackward(id, field, info) + local function checkInfoBeforeBackward(uri, id, field, info) if info.deep and mode ~= 'allref' then return false end - if not checkThenPushReject('backward', info) then + if not checkThenPushReject(uri, 'backward', info) then return false end pushInfoFilter(id, field, info) if info.dontCross then - status.dontCross = status.dontCross + 1 + dontCross = dontCross + 1 end return true end @@ -581,278 +659,245 @@ function m.searchRefsByID(status, uri, expect, mode) ---@param id string ---@param field string ---@param info node.info - local function releaseInfoAfterBackward(id, field, info) - popReject('backward', info) + local function releaseInfoAfterBackward(uri, id, field, info) + popReject(uri, 'backward', info) releaseInfoFilter(id, field, info) if info.dontCross then - status.dontCross = status.dontCross - 1 + dontCross = dontCross - 1 end end - local function checkBackward(id, node, field) + local function checkBackward(uri, id, field) if ignoredIDs[id] then return end if mode ~= 'ref' and mode ~= 'field' and mode ~= 'allref' and not field then return end - for backwardID, info in noder.eachBackward(node) do - if info and not checkInfoBeforeBackward(backwardID, field, info) then + for backwardID, info in eachBackward(nodersMap[uri], id) do + if info and not checkInfoBeforeBackward(uri, backwardID, field, info) then goto CONTINUE end if not checkInfoFilter(backwardID, field, info) then goto CONTINUE end - local targetUri, targetID = noder.getUriAndID(backwardID) - if targetUri and not files.eq(targetUri, uri) then - crossSearch(status, targetUri, targetID .. (field or ''), mode, uri) + local targetUri, targetID = getUriAndID(backwardID) + if targetUri and targetUri ~= uri then + if dontCross == 0 then + searchID(targetUri, targetID, field, uri) + end else - searchID(targetID or backwardID, field) + searchID(uri, targetID or backwardID, field) end if info then - releaseInfoAfterBackward(backwardID, field, info) + releaseInfoAfterBackward(uri, backwardID, field, info) end ::CONTINUE:: end end - local function searchSpecial(id, field) + local function searchSpecial(uri, id, field) -- Special rule: ('').XX -> stringlib.XX if id == 'str:' or id == 'dn:string' then if field or mode == 'field' then - searchID('dn:stringlib', field) + searchID(uri, 'dn:stringlib', field) else - searchID('dn:string', field) + searchID(uri, 'dn:string', field) end end end - local function checkRequire(requireName, field) - local tid = 'mainreturn' .. (field or '') + local function checkRequire(uri, requireName, field) + if not requireName then + return + end local uris = ws.findUrisByRequirePath(requireName) - footprint(status, ('require %q:\n%s'):format(requireName, table.concat(uris, '\n'))) - for _, ruri in ipairs(uris) do - if not files.eq(uri, ruri) then - crossSearch(status, ruri, tid, mode, uri) + footprint(status, 'require:', requireName) + for i = 1, #uris do + local ruri = uris[i] + if uri ~= ruri then + searchID(ruri, 'mainreturn', field, uri) end end end - local function searchGlobal(id, node, field) - if id:sub(1, 2) ~= 'g:' then + local function searchGlobal(uri, id, field) + if dontCross ~= 0 then return end - if checkLock(status, id, field) then + if ssub(id, 1, 2) ~= 'g:' then return end - local tid = id .. (field or '') - footprint(status, ('checkGlobal:%s + %s'):format(id, field, tid)) + footprint(status, 'checkGlobal:', id, field) local crossed = {} - if mode == 'def' or mode == 'alldef' then - for _, guri in collector.each('def:' .. id) do - if files.eq(uri, guri) then + if mode == 'def' + or mode == 'alldef' + or mode == 'field' + or field then + for _, guri in ceach('def:' .. id) do + if uri == guri then goto CONTINUE end - crossSearch(status, guri, tid, mode, uri) + searchID(guri, id, field, uri) ::CONTINUE:: end else - for _, guri in collector.each(id) do + for _, guri in ceach(id) do if crossed[guri] then goto CONTINUE end - if mode == 'def' or mode == 'alldef' then + if mode == 'def' or mode == 'alldef' or mode == 'field' then goto CONTINUE end - if files.eq(uri, guri) then + if uri == guri then goto CONTINUE end - crossSearch(status, guri, tid, mode, uri) + searchID(guri, id, field, uri) ::CONTINUE:: end end end - local function searchClass(id, node, field) - if id:sub(1, 3) ~= 'dn:' then + local function searchClass(uri, id, field) + if dontCross ~= 0 then return end - if checkLock(status, id, field) then + if ssub(id, 1, 3) ~= 'dn:' then return end - local tid = id .. (field or '') local sid = id if ignoredIDs[id] or id == 'dn:string' then sid = 'def:' .. sid end - for _, guri in collector.each(sid) do - if not files.eq(uri, guri) then - crossSearch(status, guri, tid, mode, uri) + for _, guri in ceach(sid) do + if uri ~= guri then + searchID(guri, id, field, uri) end end end - local function checkMainReturn(id, node, field) + local function checkMainReturn(uri, id, field) if id ~= 'mainreturn' then return end local calls = vm.getLinksTo(uri) - for _, call in ipairs(calls) do - local curi = guide.getUri(call) - local cid = ('%s%s'):format( - noder.getID(call), - field or '' - ) - if not files.eq(curi, uri) then - crossSearch(status, curi, cid, mode, uri) - end - end - end - - local function lockExpanding(id, field) - local locked = elock[id] - if locked and field then - if #locked <= #field then - if field:sub(-#locked) == locked then - footprint(status, 'elocked:', id, locked, field) - return false - end + for i = 1, #calls do + local call = calls[i] + local curi = getUri(call) + local cid = getID(call) + if curi ~= uri then + searchID(curi, cid, field, uri) end end - elock[id] = field - return true end - local function releaseExpanding(id, field) - elock[id] = nil - end + local function searchNode(uri, id, field) + local noders = nodersMap[uri] + local call = noders.call[id] + local global = isGlobalID(id) + callStack[#callStack+1] = call - local function searchNode(id, node, field) - if node.call then - callStack[#callStack+1] = node.call - end - if field == nil and node.source and not ignoredSources[id] then - for source in noder.eachSource(node) do + if field == nil and not ignoredSources[id] then + for source in eachSource(noders, id) do local force = genericCallArgs[source] - m.pushResult(status, mode, source, force) + pushResult(status, mode, source, force) end end - if node.require then - checkRequire(node.require, field) + local requireName = noders.require[id] + if requireName then + checkRequire(uri, requireName, field) end - if lockExpanding(id, field) then - if node.forward then - checkForward(id, node, field) + local elock = global and elockMap['@global'] or elockMap[uri] + + if lockExpanding(elock, id, field) then + if noders.forward[id] then + checkForward(uri, id, field) end - if node.backward then - checkBackward(id, node, field) + if noders.backward[id] then + checkBackward(uri, id, field) end - releaseExpanding(id, field) + releaseExpanding(elock, id, field) end - if node.source then - checkGeneric(node.source, field) - checkENV(node.source, field) + local source = noders.source[id] + if source then + checkGeneric(uri, source, field) + checkENV(uri, source, field) end if mode == 'allref' or mode == 'alldef' then - checkMainReturn(id, node, field) + checkMainReturn(uri, id, field) end - if node.call then + if call then callStack[#callStack] = nil end return false end - local function searchAnyField(id, field) + local function searchAnyField(uri, id, field) if mode == 'ref' or mode == 'allref' then return end - local lastID = noder.getLastID(id) + local lastID = getLastID(id) if not lastID then return end - local originField = id:sub(#lastID + 1) - if originField == noder.TABLE_KEY - or originField == noder.WEAK_TABLE_KEY then + local originField = ssub(id, #lastID + 1) + if originField == TABLE_KEY + or originField == WEAK_TABLE_KEY then return end - local anyFieldID = lastID .. noder.ANY_FIELD - local anyFieldNode = noder.getNodeByID(root, anyFieldID) - if anyFieldNode then - searchNode(anyFieldID, anyFieldNode, field) - end + local anyFieldID = lastID .. ANY_FIELD + searchNode(uri, anyFieldID, field) end - local function searchWeak(id, field) - local lastID = noder.getLastID(id) + local function searchWeak(uri, id, field) + local lastID = getLastID(id) if not lastID then return end - local originField = id:sub(#lastID + 1) - if originField == noder.WEAK_TABLE_KEY then - local newID = lastID .. noder.TABLE_KEY - local newNode = noder.getNodeByID(root, newID) - if newNode then - searchNode(newID, newNode, field) - end + local originField = ssub(id, #lastID + 1) + if originField == WEAK_TABLE_KEY then + local newID = lastID .. TABLE_KEY + searchNode(uri, newID, field) end - if originField == noder.WEAK_ANY_FIELD then - local newID = lastID .. noder.ANY_FIELD - local newNode = noder.getNodeByID(root, newID) - if newNode then - searchNode(newID, newNode, field) - end + if originField == WEAK_ANY_FIELD then + local newID = lastID .. ANY_FIELD + searchNode(uri, newID, field) end end local stepCount = 0 - local stepMaxCount = 1e3 - local statusMaxCount = 1e4 + local stepMaxCount = 1e4 if mode == 'allref' or mode == 'alldef' then - stepMaxCount = 1e4 - statusMaxCount = 1e5 + stepMaxCount = 1e5 end - function searchStep(id, field) + + function searchStep(uri, id, field) stepCount = stepCount + 1 - status.count = status.count + 1 - if stepCount > stepMaxCount - or status.count > statusMaxCount then + if stepCount > stepMaxCount then stop(status, 'too deep!') return end - searchSpecial(id, field) - local node = noder.getNodeByID(root, id) - if node then - searchNode(id, node, field) - if node.skip and field then - return - end + searchSpecial(uri, id, field) + searchNode(uri, id, field) + if field and nodersMap[uri].skip[id] then + return end - searchGlobal(id, node, field) - searchClass(id, node, field) - splitID(id, field) - searchAnyField(id, field) - searchWeak(id, field) + searchGlobal(uri, id, field) + searchClass(uri, id, field) + splitID(uri, id, field) + searchAnyField(uri, id, field) + searchWeak(uri, id, field) end - search(expect) - - --清除来自泛型的临时对象 - for _, closure in pairs(closureCache) do - noder.removeID(root, noder.getID(closure)) - if closure then - for _, value in ipairs(closure.values) do - noder.removeID(root, noder.getID(value)) - end - end - end + search(suri, expect, nil) end local function prepareSearch(source) @@ -866,44 +911,53 @@ local function prepareSearch(source) if not source then return end - local root = guide.getRoot(source) - if not root then - return - end - noder.compileNodes(root) - local uri = guide.getUri(source) - local id = noder.getID(source) + local noders = getNoders(source) + compilePartNoders(noders, source) + local uri = getUri(source) + local id = getID(source) return uri, id end +local fieldNextTypeMap = util.switch() + : case 'getmethod' + : case 'setmethod' + : case 'getfield' + : case 'setfield' + : case 'getindex' + : case 'setindex' + : call(pushResult) + : getMap() + local function getField(status, source, mode) if source.type == 'table' then - for _, field in ipairs(source) do - m.pushResult(status, mode, field) + for i = 1, #source do + local field = source[i] + pushResult(status, mode, field) end return end if source.type == 'doc.type.ltable' then - for _, field in ipairs(source.fields) do - m.pushResult(status, mode, field) + local fields = source.fields + for i = 1, #fields do + local field = fields[i] + pushResult(status, mode, field) end end if source.type == 'doc.class.name' then - local class = source.parent - for _, field in ipairs(class.fields) do - m.pushResult(status, mode, field.field) + local class = source.parent + local fields = class.fields + for i = 1, #fields do + local field = fields[i] + pushResult(status, mode, field.field) end return end local field = source.next if field then - if field.type == 'getmethod' - or field.type == 'setmethod' - or field.type == 'getfield' - or field.type == 'setfield' - or field.type == 'getindex' - or field.type == 'setindex' then - m.pushResult(status, mode, field) + local ftype = field.type + local pushResultOrNil = fieldNextTypeMap[ftype] + if pushResultOrNil then + pushResultOrNil(status, mode, field) end return end @@ -915,24 +969,18 @@ local function searchAllGlobalByUri(status, mode, uri, fullID) return end local root = ast.ast - noder.compileNodes(root) - local noders = noder.getNoders(root) + --compileAllNodes(root) + local noders = getNoders(root) if fullID then - for id, node in pairs(noders) do - if node.source - and id == fullID then - for source in noder.eachSource(node) do - m.pushResult(status, mode, source) - end - end + for source in eachSource(noders, fullID) do + pushResult(status, mode, source) end else - for id, node in pairs(noders) do - if node.source - and id:sub(1, 2) == 'g:' - and not id:find(noder.SPLIT_CHAR) then - for source in noder.eachSource(node) do - m.pushResult(status, mode, source) + for id in next, noders.source do + if ssub(id, 1, 2) == 'g:' + and not sfind(id, SPLIT_CHAR) then + for source in eachSource(noders, id) do + pushResult(status, mode, source) end end end @@ -951,10 +999,15 @@ end ---@param name string ---@return parser.guide.object[] function m.findGlobals(uri, mode, name) - local status = m.status(mode) + local status = m.status(nil, nil, mode) if name then - local fullID = ('g:%q'):format(name) + local fullID + if type(name) == 'string' then + fullID = sformat('%s%s%s', 'g:', STRING_CHAR, name) + else + fullID = sformat('%s%s%s', 'g:', '', name) + end searchAllGlobalByUri(status, mode, uri, fullID) else searchAllGlobalByUri(status, mode, uri) @@ -973,7 +1026,9 @@ function m.searchRefs(status, source, mode) return end - if checkCache(status, uri, id, mode) then + local cached, makeCache = checkCache(status, uri, id, mode) + + if cached then return end @@ -981,6 +1036,9 @@ function m.searchRefs(status, source, mode) log.debug('searchRefs:', id) end m.searchRefsByID(status, uri, id, mode) + if makeCache then + makeCache() + end end ---搜索对象的field @@ -998,33 +1056,61 @@ function m.searchFields(status, source, mode, field) end if field == '*' then if source.special == '_G' then - if checkCache(status, uri, '*', mode) then + local cached, makeCache = checkCache(status, uri, '*', mode) + if cached then return end searchAllGlobals(status, mode) + if makeCache then + makeCache() + end else - if checkCache(status, uri, id .. '*', mode) then + local cached, makeCache = checkCache(status, uri, id .. '*', mode) + if cached then return end - local newStatus = m.status('field') + local newStatus = m.status(source, field, 'field') m.searchRefsByID(newStatus, uri, id, 'field') - for _, def in ipairs(newStatus.results) do + local results = newStatus.results + for i = 1, #results do + local def = results[i] getField(status, def, mode) end + if makeCache then + makeCache() + end end else if source.special == '_G' then - local fullID = ('g:%q'):format(field) - if checkCache(status, uri, fullID, mode) then + local fullID + if type(field) == 'string' then + fullID = sformat('%s%s%s', 'g:', STRING_CHAR, field) + else + fullID = sformat('%s%s%s', 'g:', '', field) + end + local cahced, makeCache = checkCache(status, uri, fullID, mode) + if cahced then return end m.searchRefsByID(status, uri, fullID, mode) + if makeCache then + makeCache() + end else - local fullID = ('%s%s%q'):format(id, noder.SPLIT_CHAR, field) - if checkCache(status, uri, fullID, mode) then + local fullID + if type(field) == 'string' then + fullID = sformat('%s%s%s', id, STRING_FIELD, field) + else + fullID = sformat('%s%s%s', id, SPLIT_CHAR, field) + end + local cahced, makeCache = checkCache(status, uri, fullID, mode) + if cahced then return end m.searchRefsByID(status, uri, fullID, mode) + if makeCache then + makeCache() + end end end end @@ -1034,26 +1120,23 @@ end ---@field results parser.guide.object[] ---@field rmark table ---@field id string +---@field source parser.guide.object +---@field field string ---创建搜索状态 ---@param mode guide.searchmode ---@return guide.status -function m.status(mode) +function m.status(source, field, mode) local status = { callStack = {}, - crossed = {}, - lock = {}, - slock = {}, - elock = {}, results = {}, rmark = {}, - smark = {}, footprint = {}, - count = 0, - ftag = {}, - btag = {}, - dontCross = 0, - cache = vm.getCache('searcher:' .. mode) + ids = {}, + mode = mode, + source = source, + field = field, + cache = setmetatable(vm.getCache('searcher:' .. mode), uriMapMT), } return status end @@ -1063,7 +1146,7 @@ end ---@param field? string ---@return parser.guide.object[] function m.requestReference(obj, field) - local status = m.status('ref') + local status = m.status(obj, field, 'ref') if field then m.searchFields(status, obj, 'ref', field) @@ -1079,7 +1162,7 @@ end ---@param field? string ---@return parser.guide.object[] function m.requestAllReference(obj, field) - local status = m.status('allref') + local status = m.status(obj, field, 'allref') if field then m.searchFields(status, obj, 'allref', field) @@ -1095,7 +1178,7 @@ end ---@param field? string ---@return parser.guide.object[] function m.requestDefinition(obj, field) - local status = m.status('def') + local status = m.status(obj, field, 'def') if field then m.searchFields(status, obj, 'def', field) @@ -1111,7 +1194,7 @@ end ---@param field? string ---@return parser.guide.object[] function m.requestAllDefinition(obj, field) - local status = m.status('alldef') + local status = m.status(obj, field, 'alldef') if field then m.searchFields(status, obj, 'alldef', field) diff --git a/script/core/type-definition.lua b/script/core/type-definition.lua index 4dd9ac32..6d45b601 100644 --- a/script/core/type-definition.lua +++ b/script/core/type-definition.lua @@ -9,8 +9,8 @@ local infer = require 'core.infer' local function sortResults(results) -- 先按照顺序排序 table.sort(results, function (a, b) - local u1 = searcher.getUri(a.target) - local u2 = searcher.getUri(b.target) + local u1 = guide.getUri(a.target) + local u2 = guide.getUri(b.target) if u1 == u2 then return a.target.start < b.target.start else @@ -22,7 +22,7 @@ local function sortResults(results) for i = #results, 1, -1 do local res = results[i].target local f = res.finish - local uri = searcher.getUri(res) + local uri = guide.getUri(res) if lf and f > lf and uri == lu then table.remove(results, i) else |