diff options
-rw-r--r-- | script/vm/compiler.lua | 148 | ||||
-rw-r--r-- | script/vm/generic-manager.lua | 8 | ||||
-rw-r--r-- | script/vm/getDef.lua | 7 | ||||
-rw-r--r-- | script/vm/node.lua | 56 | ||||
-rw-r--r-- | test/definition/luadoc.lua | 9 |
5 files changed, 121 insertions, 107 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index ff13db1a..aad5a836 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1,8 +1,9 @@ local guide = require 'parser.guide' local util = require 'utility' -local union = require 'vm.union' local localID = require 'vm.local-id' local globalMgr = require 'vm.global-manager' +local nodeMgr = require 'vm.node' +local genericMgr = require 'vm.generic-manager' ---@class parser.object ---@field _compiledNodes boolean @@ -11,52 +12,6 @@ local globalMgr = require 'vm.global-manager' ---@class vm.node.compiler local m = {} -local nodeCache = {} - ----@alias vm.node parser.object | vm.node.union | vm.node.global | vm.node.generic - ----@param a vm.node ----@param b vm.node -function m.mergeNode(a, b) - if not b then - return a - end - if a.type == 'union' then - a:merge(b) - return a - end - return union(a, b) -end - -function m.setNode(source, node) - if not node then - return - end - local me = nodeCache[source] - if not me then - nodeCache[source] = node - return - end - if me == node then - return - end - nodeCache[source] = m.mergeNode(me, node) -end - -function m.eachNode(node) - if node.type == 'union' then - return node:eachNode() - end - local first = true - return function () - if first then - first = false - return node - end - return nil - end -end - local searchFieldMap = util.switch() : case 'table' : call(function (node, key, pushResult) @@ -161,14 +116,14 @@ local function getReturnOfSetMetaTable(source, args) local tbl = args and args[1] local mt = args and args[2] if tbl then - m.setNode(source, m.compileNode(tbl)) + nodeMgr.setNode(source, m.compileNode(tbl)) end if mt then m.compileByParentNode(mt, '__index', function (src) - m.setNode(source, m.compileNode(src)) + nodeMgr.setNode(source, m.compileNode(src)) end) end - return nodeCache[source] + return nodeMgr.nodeCache[source] end local function getReturn(func, index, source, args) @@ -177,7 +132,7 @@ local function getReturn(func, index, source, args) end local node = m.compileNode(func) if node then - for cnode in m.eachNode(node) do + for cnode in nodeMgr.eachNode(node) do if cnode.type == 'function' then local returnNode = getReturnOfFunction(cnode, index) if returnNode and returnNode.type == 'generic' then @@ -187,11 +142,11 @@ local function getReturn(func, index, source, args) end returnNode = returnNode:resolve(argNodes) end - m.setNode(source, returnNode) + nodeMgr.setNode(source, returnNode) end end end - return nodeCache[source] + return nodeMgr.nodeCache[source] end local function bindDocs(source) @@ -201,20 +156,20 @@ local function bindDocs(source) if doc.type == 'doc.type' then if not isParam then hasFounded = true - m.setNode(source, m.compileNode(doc)) + nodeMgr.setNode(source, m.compileNode(doc)) end end if doc.type == 'doc.class' then if source.type == 'local' or (source._globalNode and guide.isSet(source)) then hasFounded = true - m.setNode(source, m.compileNode(doc)) + nodeMgr.setNode(source, m.compileNode(doc)) end end if doc.type == 'doc.param' then if isParam and source[1] == doc.param[1] then hasFounded = true - m.setNode(source, m.compileNode(doc)) + nodeMgr.setNode(source, m.compileNode(doc)) end end end @@ -231,14 +186,14 @@ local function compileByLocalID(source) if src.bindDocs then if bindDocs(src) then hasMarkDoc = true - m.setNode(source, m.compileNode(src)) + nodeMgr.setNode(source, m.compileNode(src)) end end end for _, src in ipairs(sources) do if src.value then if not hasMarkDoc or guide.isLiteral(src.value) then - m.setNode(source, m.compileNode(src.value)) + nodeMgr.setNode(source, m.compileNode(src.value)) end end end @@ -252,7 +207,7 @@ function m.compileByParentNode(source, key, pushResult) if not parentNode then return end - for node in m.eachNode(parentNode) do + for node in nodeMgr.eachNode(parentNode) do local f = searchFieldMap[node.type] if f then f(node, key, pushResult) @@ -298,7 +253,6 @@ local function getFunctionGeneric(func) return func._generic end func._generic = false - local genericMgr = require 'vm.generic-manager' for _, doc in ipairs(func.bindDocs) do if doc.type == 'doc.generic' then if not func._generic then @@ -327,17 +281,17 @@ local compilerMap = util.switch() : case 'doc.type.table' : call(function (source) --localMgr.declareLocal(source) - m.setNode(source, source) + nodeMgr.setNode(source, source) end) : case 'function' : call(function (source) --localMgr.declareLocal(source) - m.setNode(source, source) + nodeMgr.setNode(source, source) if source.bindDocs then for _, doc in ipairs(source.bindDocs) do if doc.type == 'doc.overload' then - m.setNode(source, m.compileNode(doc)) + nodeMgr.setNode(source, m.compileNode(doc)) end end end @@ -345,7 +299,7 @@ local compilerMap = util.switch() : case 'local' : call(function (source) --localMgr.declareLocal(source) - m.setNode(source, source) + nodeMgr.setNode(source, source) local hasMarkDoc if source.bindDocs then hasMarkDoc = bindDocs(source) @@ -353,16 +307,16 @@ local compilerMap = util.switch() if source.ref and not hasMarkDoc then for _, ref in ipairs(source.ref) do if ref.type == 'setlocal' then - m.setNode(source, m.compileNode(ref.value)) + nodeMgr.setNode(source, m.compileNode(ref.value)) end end end if source.dummy and not hasMarkDoc then - m.setNode(source, m.compileNode(source.method.node)) + nodeMgr.setNode(source, m.compileNode(source.method.node)) end if source.value then if not hasMarkDoc or guide.isLiteral(source.value) then - m.setNode(source, m.compileNode(source.value)) + nodeMgr.setNode(source, m.compileNode(source.value)) end end -- function x.y(self, ...) --> function x:y(...) @@ -372,13 +326,13 @@ local compilerMap = util.switch() and source.parent[1] == source then local setfield = source.parent.parent.parent if setfield.type == 'setfield' then - m.setNode(source, m.compileNode(setfield.node)) + nodeMgr.setNode(source, m.compileNode(setfield.node)) end end end) : case 'getlocal' : call(function (source) - m.setNode(source, m.compileNode(source.node)) + nodeMgr.setNode(source, m.compileNode(source.node)) end) : case 'setfield' : case 'setmethod' @@ -392,20 +346,20 @@ local compilerMap = util.switch() : call(function (source) compileByLocalID(source) m.compileByParentNode(source.node, guide.getKeyName(source), function (src) - m.setNode(source, m.compileNode(src)) + nodeMgr.setNode(source, m.compileNode(src)) end) end) : case 'tablefield' : case 'tableindex' : call(function (source) if source.value then - m.setNode(source, m.compileNode(source.value)) + nodeMgr.setNode(source, m.compileNode(source.value)) end end) : case 'field' : case 'method' : call(function (source) - m.setNode(source, m.compileNode(source.parent)) + nodeMgr.setNode(source, m.compileNode(source.parent)) end) : case 'function.return' : call(function (source) @@ -427,9 +381,9 @@ local compilerMap = util.switch() end local rtnNode = m.compileNode(rtn) if hasGeneric then - m.setNode(source, generic:getChild(rtnNode)) + nodeMgr.setNode(source, generic:getChild(rtnNode)) else - m.setNode(source, rtnNode) + nodeMgr.setNode(source, rtnNode) end end end @@ -438,7 +392,7 @@ local compilerMap = util.switch() end if func.returns and not hasMarkDoc then for _, rtn in ipairs(func.returns) do - m.setNode(source, selectNode(source, rtn, index)) + nodeMgr.setNode(source, selectNode(source, rtn, index)) end end end) @@ -446,30 +400,34 @@ local compilerMap = util.switch() : call(function (source) local vararg = source.vararg if vararg.type == 'call' then - m.setNode(source, getReturn(vararg.node, source.sindex, source, vararg.args)) + nodeMgr.setNode(source, getReturn(vararg.node, source.sindex, source, vararg.args)) end end) : case 'doc.type' : call(function (source) for _, typeUnit in ipairs(source.types) do - m.setNode(source, m.compileNode(typeUnit)) + nodeMgr.setNode(source, m.compileNode(typeUnit)) end end) + : case 'doc.type.array' + : case(function (source) + nodeMgr.setNode(source, source) + end) : case 'doc.generic.name' : call(function (source) - m.setNode(source, source) + nodeMgr.setNode(source, source) end) : case 'doc.field' : call(function (source) - m.setNode(source, m.compileNode(source.extends)) + nodeMgr.setNode(source, m.compileNode(source.extends)) end) : case 'doc.param' : call(function (source) - m.setNode(source, m.compileNode(source.extends)) + nodeMgr.setNode(source, m.compileNode(source.extends)) end) : case 'doc.vararg' : call(function (source) - m.setNode(source, m.compileNode(source.vararg)) + nodeMgr.setNode(source, m.compileNode(source.vararg)) end) : case '...' : call(function (source) @@ -482,22 +440,22 @@ local compilerMap = util.switch() end for _, doc in ipairs(func.bindDocs) do if doc.type == 'doc.vararg' then - m.setNode(source, m.compileNode(doc)) + nodeMgr.setNode(source, m.compileNode(doc)) end if doc.type == 'doc.param' and doc.param[1] == '...' then - m.setNode(source, m.compileNode(doc)) + nodeMgr.setNode(source, m.compileNode(doc)) end end end) : case 'doc.overload' : call(function (source) - m.setNode(source, m.compileNode(source.overload)) + nodeMgr.setNode(source, m.compileNode(source.overload)) end) : case 'doc.see.name' : call(function (source) local type = globalMgr.getGlobal('type', source[1]) if type then - m.setNode(source, m.compileNode(type)) + nodeMgr.setNode(source, m.compileNode(type)) end end) : getMap() @@ -513,17 +471,17 @@ end ---@param source parser.object local function compileByGlobal(source) if source.type == 'global' then - m.setNode(source, source) + nodeMgr.setNode(source, source) return end if source._globalNode then - m.setNode(source, source._globalNode) + nodeMgr.setNode(source, source._globalNode) if source._globalNode.cate == 'variable' then local hasMarkDoc for _, set in ipairs(source._globalNode:getSets()) do if set.bindDocs then if bindDocs(set) then - m.setNode(source, m.compileNode(set)) + nodeMgr.setNode(source, m.compileNode(set)) hasMarkDoc = true end end @@ -531,7 +489,7 @@ local function compileByGlobal(source) for _, set in ipairs(source._globalNode:getSets()) do if set.value then if not hasMarkDoc or guide.isLiteral(set.value) then - m.setNode(source, m.compileNode(set.value)) + nodeMgr.setNode(source, m.compileNode(set.value)) end end end @@ -543,20 +501,16 @@ end ---@param source parser.object ---@return vm.node function m.compileNode(source) - if nodeCache[source] ~= nil then - return nodeCache[source] + if nodeMgr.nodeCache[source] ~= nil then + return nodeMgr.nodeCache[source] end - nodeCache[source] = false + nodeMgr.nodeCache[source] = false compileByGlobal(source) compileByNode(source) --localMgr.subscribeLocal(source, source._node) - return nodeCache[source] -end - -function m.clearNodeCache() - nodeCache = {} + return nodeMgr.nodeCache[source] end return m diff --git a/script/vm/generic-manager.lua b/script/vm/generic-manager.lua index 4b8b4a3c..4a211076 100644 --- a/script/vm/generic-manager.lua +++ b/script/vm/generic-manager.lua @@ -1,7 +1,7 @@ local createGeneric = require 'vm.generic' -local compiler = require 'vm.compiler' local globalMgr = require 'vm.global-manager' local guide = require 'parser.guide' +local nodeMgr = require 'vm.node' ---@class vm.node.generic-manager ---@field parent parser.object @@ -34,14 +34,14 @@ function mt:resolve(argNodes) if typeUnit.type == 'doc.generic.name' then local key = typeUnit[1] if typeUnit.literal then - for n in compiler.eachNode(node) do + for n in nodeMgr.eachNode(node) do if n.type == 'string' then local type = globalMgr.declareGlobal('type', n[1], guide.getUri(n)) - resolved[key] = compiler.mergeNode(type, resolved[key]) + resolved[key] = nodeMgr.mergeNode(type, resolved[key]) end end else - resolved[key] = compiler.mergeNode(node, resolved[key]) + resolved[key] = nodeMgr.mergeNode(node, resolved[key]) end end end diff --git a/script/vm/getDef.lua b/script/vm/getDef.lua index 0302f9d3..1665154b 100644 --- a/script/vm/getDef.lua +++ b/script/vm/getDef.lua @@ -5,6 +5,7 @@ local compiler = require 'vm.compiler' local guide = require 'parser.guide' local localID = require 'vm.local-id' local globalMgr = require 'vm.global-manager' +local nodeMgr = require 'vm.node' local simpleMap @@ -132,7 +133,7 @@ local nodeMap = util.switch() return end local key = guide.getKeyName(source) - for pn in compiler.eachNode(parentNode) do + for pn in nodeMgr.eachNode(parentNode) do if searchFieldMap[pn.type] then searchFieldMap[pn.type](pn, key, pushResult) end @@ -144,7 +145,7 @@ local nodeMap = util.switch() if not parentNode then return end - for pn in compiler.eachNode(parentNode) do + for pn in nodeMgr.eachNode(parentNode) do if searchFieldMap[pn.type] then searchFieldMap[pn.type](pn, source[1], pushResult) end @@ -189,7 +190,7 @@ local function searchByNode(source, pushResult) if not node then return end - for n in compiler.eachNode(node) do + for n in nodeMgr.eachNode(node) do if n.type == 'global' and n.cate == 'type' then for _, set in ipairs(n:getSets()) do pushResult(set) diff --git a/script/vm/node.lua b/script/vm/node.lua new file mode 100644 index 00000000..79a4d2cb --- /dev/null +++ b/script/vm/node.lua @@ -0,0 +1,56 @@ +local union = require 'vm.union' + +---@alias vm.node parser.object | vm.node.union | vm.node.global | vm.node.generic + +local m = {} + +---@type table<parser.object, vm.node> +m.nodeCache = {} + +---@param a vm.node +---@param b vm.node +function m.mergeNode(a, b) + if not b then + return a + end + if a.type == 'union' then + a:merge(b) + return a + end + return union(a, b) +end + +function m.setNode(source, node) + if not node then + return + end + local me = m.nodeCache[source] + if not me then + m.nodeCache[source] = node + return + end + if me == node then + return + end + m.nodeCache[source] = m.mergeNode(me, node) +end + +function m.eachNode(node) + if node.type == 'union' then + return node:eachNode() + end + local first = true + return function () + if first then + first = false + return node + end + return nil + end +end + +function m.clearNodeCache() + m.nodeCache = {} +end + +return m diff --git a/test/definition/luadoc.lua b/test/definition/luadoc.lua index 74770794..d1a4aaf0 100644 --- a/test/definition/luadoc.lua +++ b/test/definition/luadoc.lua @@ -297,19 +297,22 @@ print(v1.<?bar1?>) TEST [[ ---@class A -local <!t!> +local t + +t.<!x!> = 1 ---@type A[] local b -local <?<!c!>?> = b[1] +local c = b[1] +c.<?x?> ]] TEST [[ ---@class A local <!t!> ----@type table<number, A> +---@type { [number]: A } local b local <?<!c!>?> = b[1] |