From 256b37771a9203ab0c27d5690e35d9a1f9185465 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=80=E8=90=8C=E5=B0=8F=E6=B1=90?= Date: Tue, 8 Mar 2022 19:41:06 +0800 Subject: update --- script/core/definition.lua | 2 +- script/core/infer.lua | 637 ---------------------------------------- script/core/type-definition.lua | 2 +- script/vm/def.lua | 250 ++++++++++++++++ script/vm/getDef.lua | 256 ---------------- script/vm/infer.lua | 52 ++++ script/vm/init.lua | 2 +- script/vm/node.lua | 5 + 8 files changed, 310 insertions(+), 896 deletions(-) delete mode 100644 script/core/infer.lua create mode 100644 script/vm/def.lua delete mode 100644 script/vm/getDef.lua create mode 100644 script/vm/infer.lua (limited to 'script') diff --git a/script/core/definition.lua b/script/core/definition.lua index a01a6a25..78ebce8f 100644 --- a/script/core/definition.lua +++ b/script/core/definition.lua @@ -130,7 +130,7 @@ return function (uri, offset) end end - local defs = vm.getAllDefs(source) + local defs = vm.getDefs(source) for _, src in ipairs(defs) do if src.dummy then diff --git a/script/core/infer.lua b/script/core/infer.lua deleted file mode 100644 index 35a054ca..00000000 --- a/script/core/infer.lua +++ /dev/null @@ -1,637 +0,0 @@ -local config = require 'config' -local util = require 'utility' -local vm = require "vm.vm" -local guide = require "parser.guide" - -local CLASS = { 'CLASS' } -local TABLE = { 'TABLE' } -local CACHE = { 'CACHE' } - -local typeSort = { - ['boolean'] = 1, - ['string'] = 2, - ['integer'] = 3, - ['number'] = 4, - ['table'] = 5, - ['function'] = 6, - ['true'] = 101, - ['false'] = 102, -} - -local m = {} - -local function mergeTable(a, b) - if not b then - return - end - for v in pairs(b) do - a[v] = true - end - a[CACHE] = nil -end - -local function isBaseType(source, mark) - return m.hasType(source, 'number', mark) - or m.hasType(source, 'integer', mark) - or m.hasType(source, 'string', mark) -end - -local function searchInferOfUnary(value, infers, mark) - local op = value.op.type - if op == 'not' then - infers['boolean'] = true - return - end - if op == '#' then - if m.hasType(value[1], 'table', mark) - or m.hasType(value[1], 'string', mark) then - infers['integer'] = true - end - return - end - if op == '-' then - if m.hasType(value[1], 'integer', mark) then - infers['integer'] = true - elseif isBaseType(value[1], mark) then - infers['number'] = true - end - return - end - if op == '~' then - if isBaseType(value[1], mark) then - infers['integer'] = true - end - return - end -end - -local function searchInferOfBinary(value, infers, mark) - local op = value.op.type - if op == 'and' then - if m.isTrue(value[1], mark) then - mergeTable(infers, m.searchInfers(value[2], nil, mark)) - else - mergeTable(infers, m.searchInfers(value[1], nil, mark)) - end - return - end - if op == 'or' then - if m.isTrue(value[1], mark) then - mergeTable(infers, m.searchInfers(value[1], nil, mark)) - else - mergeTable(infers, m.searchInfers(value[2], nil, mark)) - end - return - end - -- must return boolean - if op == '==' - or op == '~=' - or op == '<' - or op == '>' - or op == '<=' - or op == '>=' then - infers['boolean'] = true - return - end - -- check number - if op == '<<' - or op == '>>' - or op == '~' - or op == '&' - or op == '|' then - if isBaseType(value[1], mark) - and isBaseType(value[2], mark) then - infers['integer'] = true - end - return - end - if op == '..' then - if isBaseType(value[1], mark) - and isBaseType(value[2], mark) then - infers['string'] = true - end - return - end - if op == '^' - or op == '/' then - if isBaseType(value[1], mark) - and isBaseType(value[2], mark) then - infers['number'] = true - end - return - end - if op == '+' - or op == '-' - or op == '*' - or op == '%' - or op == '//' then - if m.hasType(value[1], 'integer', mark) - and m.hasType(value[2], 'integer', mark) then - infers['integer'] = true - elseif isBaseType(value[1], mark) - and isBaseType(value[2], mark) then - infers['number'] = true - end - return - end -end - -local function searchInferOfValue(value, infers, mark) - if value.type == 'string' then - infers['string'] = true - return true - end - if value.type == 'boolean' then - infers['boolean'] = true - return true - end - if value.type == 'table' then - if value.array then - local node = m.searchAndViewInfers(value.array, nil, mark) - if node ~= 'any' then - local infer = node .. '[]' - infers[infer] = true - end - else - infers['table'] = true - end - return true - end - if value.type == 'integer' then - infers['integer'] = true - return true - end - if value.type == 'number' then - infers['number'] = true - return true - end - if value.type == 'function' then - infers['function'] = true - return true - end - if value.type == 'unary' then - searchInferOfUnary(value, infers, mark) - return true - end - if value.type == 'binary' then - searchInferOfBinary(value, infers, mark) - return true - end - return false -end - -local function searchLiteralOfValue(value, literals, mark) - if value.type == 'string' - or value.type == 'boolean' - or value.type == 'number' - or value.type == 'integer' then - local v = value[1] - if v ~= nil then - literals[v] = true - end - return - end - if value.type == 'unary' then - local op = value.op.type - if op == '-' then - local subLiterals = m.searchLiterals(value[1], nil, mark) - if subLiterals then - for subLiteral in pairs(subLiterals) do - local num = tonumber(subLiteral) - if num then - literals[-num] = true - end - end - end - end - if op == '~' then - local subLiterals = m.searchLiterals(value[1], nil, mark) - if subLiterals then - for subLiteral in pairs(subLiterals) do - local num = math.tointeger(subLiteral) - if num then - literals[~num] = true - end - end - end - end - end -end - -local function bindClassOrType(source) - if not source.bindDocs then - return false - end - for _, doc in ipairs(source.bindDocs) do - if doc.type == 'doc.class' - or doc.type == 'doc.type' then - return true - end - end - return false -end - -local function cleanInfers(uri, infers) - local version = config.get(uri, 'Lua.runtime.version') - local enableInteger = version == 'Lua 5.3' or version == 'Lua 5.4' - infers['unknown'] = nil - if infers['number'] then - enableInteger = false - end - if not enableInteger and infers['integer'] then - infers['integer'] = nil - infers['number'] = true - end - -- stringlib 就是 string - if infers['stringlib'] and infers['string'] then - infers['stringlib'] = nil - end - -- 如果有doc标记,则先移除table类型 - if infers[CLASS] then - infers[CLASS] = nil - infers['table'] = nil - end - -- 用doc标记的table,加入table类型 - if infers[TABLE] then - infers[TABLE] = nil - infers['table'] = true - end - if infers['function'] then - for k in pairs(infers) do - if k:sub(1, 4) == 'fun(' then - infers[k] = nil - end - end - end -end - ----合并对象的推断类型 ----@param infers string[] ----@return string -function m.viewInfers(uri, infers) - if infers[CACHE] then - return infers[CACHE] - end - -- 如果有显性的 any ,则直接显示为 any - if infers['any'] then - infers[CACHE] = 'any' - return 'any' - end - local result = {} - local count = 0 - for infer in pairs(infers) do - count = count + 1 - result[count] = infer - end - -- 如果没有任何显性类型,则推测为 unkonwn ,显示为 any - if count == 0 then - infers[CACHE] = 'any' - return 'any' - end - table.sort(result, function (a, b) - local sa = typeSort[a] or 100 - local sb = typeSort[b] or 100 - if sa == sb then - return a < b - else - return sa < sb - end - end) - local limit = config.get(uri, 'Lua.hover.enumsLimit') - if limit < 0 then - limit = 0 - end - infers[CACHE] = table.concat(result, '|', 1, math.min(count, limit)) - if count > limit then - infers[CACHE] = ('%s...(+%d)'):format(infers[CACHE], count - limit) - end - return infers[CACHE] -end - ----合并对象的值 ----@param literals string[] ----@return string -function m.viewLiterals(literals) - local result = {} - local count = 0 - for infer in pairs(literals) do - count = count + 1 - result[count] = util.viewLiteral(infer) - end - if count == 0 then - return nil - end - table.sort(result) - local view = table.concat(result, '|') - return view -end - -function m.viewDocName(doc) - if not doc then - return nil - end - if doc.type == 'doc.type' then - local list = {} - for _, tp in ipairs(doc.types) do - list[#list+1] = m.getDocName(tp) - end - return table.concat(list, '|') - end - return m.getDocName(doc) -end - -function m.getDocName(doc) - if not doc then - return nil - end - if doc.type == 'doc.class.name' - or doc.type == 'doc.type.name' then - local name = doc[1] or '?' - if doc.typeGeneric then - return '<' .. name .. '>' - else - return tostring(name) - end - end - if doc.type == 'doc.type.array' then - local nodeName = m.viewDocName(doc.node) or '?' - return nodeName .. '[]' - end - if doc.type == 'doc.type.table' then - local node = m.viewDocName(doc.node) or '?' - local key = m.viewDocName(doc.tkey) or '?' - local value = m.viewDocName(doc.tvalue) or '?' - return ('%s<%s, %s>'):format(node, key, value) - end - if doc.type == 'doc.type.function' then - return m.viewDocFunction(doc) - end - if doc.type == 'doc.type.enum' - or doc.type == 'doc.resume' then - local value = doc[1] or '?' - return tostring(value) - end - if doc.type == 'doc.type.ltable' then - return 'table' - end -end - -function m.viewDocFunction(doc) - if doc.type ~= 'doc.type.function' then - return '' - end - local args = {} - for i, arg in ipairs(doc.args) do - args[i] = ('%s: %s'):format(arg.name[1], arg.extends and m.viewDocName(arg.extends) or 'any') - end - local label = ('fun(%s)'):format(table.concat(args, ', ')) - if #doc.returns > 0 then - local returns = {} - for i, rtn in ipairs(doc.returns) do - returns[i] = m.viewDocName(rtn) - end - label = ('%s:%s'):format(label, table.concat(returns, ', ')) - end - return label -end - ----显示对象的推断类型 ----@param source parser.object ----@param mark table ----@return string -local function searchInfer(source, infers, mark) - if mark[source] then - return - end - mark[source] = true - if bindClassOrType(source) then - return - end - if searchInferOfValue(source, infers, mark) then - return - end - local value = searcher.getObjectValue(source) - if value then - if value.type ~= 'function' - and value.type ~= 'table' - and value.type ~= 'nil' then - searchInferOfValue(value, infers, mark) - end - return - end - -- check LuaDoc - local docName = m.getDocName(source) - if docName and docName ~= 'nil' and docName ~= 'unknown' then - infers[docName] = true - if not vm.isBuiltinType(docName) then - infers[CLASS] = true - end - if docName == 'table' then - infers[TABLE] = true - end - end -end - -local function searchLiteral(source, literals, mark) - if mark[source] then - return - end - mark[source] = true - searchLiteralOfValue(source, literals, mark) - local value = searcher.getObjectValue(source) - if value then - if value.type ~= 'function' - and value.type ~= 'table' then - searchLiteralOfValue(value, literals, mark) - end - return - end -end - -local function getCachedInfers(source, field) - local inferCache = vm.getCache 'infers' - local sourceCache = inferCache[source] - if not sourceCache then - sourceCache = {} - inferCache[source] = sourceCache - end - if not field then - field = '' - end - if sourceCache[field] then - return true, sourceCache[field] - end - local infers = {} - sourceCache[field] = infers - return false, infers -end - ----搜索对象的推断类型 ----@param source parser.object ----@param field? string ----@param mark? table ----@return string[] -function m.searchInfers(source, field, mark) - if not source then - return nil - end - if source.type == 'setlocal' - or source.type == 'getlocal' then - source = source.node - end - local suc, infers = getCachedInfers(source, field) - if suc then - return infers - end - local isParam = source.parent.type == 'funcargs' - local defs = vm.getDefs(source, field) - mark = mark or {} - if not field then - searchInfer(source, infers, mark) - end - for _, def in ipairs(defs) do - if def.typeGeneric and not isParam then - goto CONTINUE - end - if def.type == 'setlocal' then - goto CONTINUE - end - searchInfer(def, infers, mark) - ::CONTINUE:: - end - if source.type == 'doc.type' then - for _, def in ipairs(source.types) do - if def.typeGeneric then - searchInfer(def, infers, mark) - end - end - end - cleanInfers(guide.getUri(source), infers) - return infers -end - ----搜索对象的字面量值 ----@param source parser.object ----@param field? string ----@param mark? table ----@return table -function m.searchLiterals(source, field, mark) - if not source then - return nil - end - local defs = vm.getDefs(source, field) - local literals = {} - mark = mark or {} - if not field then - searchLiteral(source, literals, mark) - end - for _, def in ipairs(defs) do - searchLiteral(def, literals, mark) - end - return literals -end - ----搜索并显示推断值 ----@param source parser.object ----@param field? string ----@return string -function m.searchAndViewLiterals(source, field, mark) - if not source then - return nil - end - local literals = m.searchLiterals(source, field, mark) - if not literals then - return nil - end - local view = m.viewLiterals(literals) - return view -end - ----判断对象的推断值是否是 true ----@param source parser.object ----@param mark? table -function m.isTrue(source, mark) - if not source then - return false - end - mark = mark or {} - if not mark.isTrue then - mark.isTrue = {} - end - if mark.isTrue[source] == nil then - mark.isTrue[source] = false - local literals = m.searchLiterals(source, nil, mark) - if literals then - for literal in pairs(literals) do - if literal ~= false then - mark.isTrue[source] = true - break - end - end - end - end - return mark.isTrue[source] -end - ----判断对象的推断类型是否包含某个类型 -function m.hasType(source, tp, mark) - mark = mark or {} - local infers = m.searchInfers(source, nil, mark) - if not infers then - return false - end - if infers[tp] then - return true - end - if tp == 'function' then - for infer in pairs(infers) do - if infer ~= CACHE and infer:sub(1, 4) == 'fun(' then - return true - end - end - end - return false -end - ----搜索并显示推断类型 ----@param source parser.object ----@param field? string ----@return string -function m.searchAndViewInfers(source, field, mark) - if not source then - return 'any' - end - local infers = m.searchInfers(source, field, mark) - local view = m.viewInfers(guide.getUri(source), infers) - if type(view) == 'boolean' then - log.error('Why view is boolean?', util.dump(infers)) - return 'any' - end - return view -end - ----搜索并显示推断的class ----@param source parser.object ----@return string? -function m.getClass(source) - if not source then - return nil - end - local infers = {} - local defs = vm.getDefs(source) - for _, def in ipairs(defs) do - if def.type == 'doc.class.name' then - if not vm.isBuiltinType(def[1]) then - infers[def[1]] = true - end - end - end - cleanInfers(guide.getUri(source), infers) - local view = m.viewInfers(guide.getUri(source), infers) - if view == 'any' then - return nil - end - return view -end - -return m diff --git a/script/core/type-definition.lua b/script/core/type-definition.lua index 1f021fb3..c9053fab 100644 --- a/script/core/type-definition.lua +++ b/script/core/type-definition.lua @@ -132,7 +132,7 @@ return function (uri, offset) end end - local defs = vm.getAllDefs(source) + local defs = vm.getDefs(source) local values = {} for _, src in ipairs(defs) do local value = searcher.getObjectValue(src) diff --git a/script/vm/def.lua b/script/vm/def.lua new file mode 100644 index 00000000..49a4f7ef --- /dev/null +++ b/script/vm/def.lua @@ -0,0 +1,250 @@ +---@class vm +local vm = require 'vm.vm' +local util = require 'utility' +local compiler = require 'vm.compiler' +local guide = require 'parser.guide' +local localID = require 'vm.local-id' +local globalMgr = require 'vm.global-manager' +local nodeMgr = require 'vm.node' + +local simpleMap + +local function searchGetLocal(source, node, pushResult) + local key = guide.getKeyName(source) + for _, ref in ipairs(node.node.ref) do + if ref.type == 'getlocal' + and guide.isSet(ref.next) + and guide.getKeyName(ref.next) == key then + pushResult(ref.next) + end + end +end + +simpleMap = util.switch() + : case 'local' + : call(function (source, pushResult) + pushResult(source) + if source.ref then + for _, ref in ipairs(source.ref) do + if ref.type == 'setlocal' then + pushResult(ref) + end + end + end + + if source.dummy then + for _, res in ipairs(vm.getDefs(source.method.node)) do + pushResult(res) + end + end + end) + : case 'getlocal' + : case 'setlocal' + : call(function (source, pushResult) + simpleMap['local'](source.node, pushResult) + end) + : case 'field' + : call(function (source, pushResult) + local parent = source.parent + simpleMap[parent.type](parent, pushResult) + end) + : case 'setfield' + : case 'getfield' + : call(function (source, pushResult) + local node = source.node + if node.type == 'getlocal' then + searchGetLocal(source, node, pushResult) + return + end + end) + : case 'getindex' + : case 'setindex' + : call(function (source, pushResult) + local node = source.node + if node.type == 'getlocal' then + searchGetLocal(source, node, pushResult) + end + end) + : case 'goto' + : call(function (source, pushResult) + if source.node then + pushResult(source.node) + end + end) + : getMap() + +local searchFieldMap = util.switch() + : case 'table' + : call(function (node, key, pushResult) + for _, field in ipairs(node) do + if field.type == 'tablefield' + or field.type == 'tableindex' then + if guide.getKeyName(field) == key then + pushResult(field) + end + end + end + end) + : case 'global' + ---@param node vm.node + ---@param key string + : call(function (node, key, pushResult) + if node.cate == 'variable' then + local newGlobal = globalMgr.getGlobal('variable', node.name, key) + if newGlobal then + for _, set in ipairs(newGlobal:getSets()) do + pushResult(set) + end + end + end + if node.cate == 'type' then + compiler.getClassFields(node, key, pushResult) + end + end) + : case 'local' + : call(function (node, key, pushResult) + local sources = localID.getSources(node, key) + if sources then + for _, src in ipairs(sources) do + if guide.isSet(src) then + pushResult(src) + end + end + end + end) + : case 'doc.type.table' + : call(function (node, key, pushResult) + for _, field in ipairs(node.fields) do + local fieldKey = field.name + if fieldKey.type == 'doc.field.name' then + if fieldKey[1] == key then + pushResult(field) + end + end + end + end) + : getMap() + +local searchByParentNode +local nodeMap = util.switch() + : case 'field' + : case 'method' + : call(function (source, pushResult) + searchByParentNode(source.parent, pushResult) + end) + : case 'getfield' + : case 'setfield' + : case 'getmethod' + : case 'setmethod' + : case 'getindex' + : case 'setindex' + : call(function (source, pushResult) + local parentNode = compiler.compileNode(source.node) + if not parentNode then + return + end + local key = guide.getKeyName(source) + for pn in nodeMgr.eachNode(parentNode) do + if searchFieldMap[pn.type] then + searchFieldMap[pn.type](pn, key, pushResult) + end + end + end) + : case 'doc.see.field' + : call(function (source, pushResult) + local parentNode = compiler.compileNode(source.parent.name) + if not parentNode then + return + end + for pn in nodeMgr.eachNode(parentNode) do + if searchFieldMap[pn.type] then + searchFieldMap[pn.type](pn, source[1], pushResult) + end + end + end) + : getMap() + + ---@param source parser.object + ---@param pushResult fun(src: parser.object) +local function searchBySimple(source, pushResult) + local simple = simpleMap[source.type] + if simple then + simple(source, pushResult) + end +end + +---@param source parser.object +---@param pushResult fun(src: parser.object) +local function searchByLocalID(source, pushResult) + local idSources = localID.getSources(source) + if not idSources then + return + end + for _, src in ipairs(idSources) do + if guide.isSet(src) then + pushResult(src) + end + end +end + +---@param source parser.object +---@param pushResult fun(src: parser.object) +function searchByParentNode(source, pushResult) + local node = nodeMap[source.type] + if node then + node(source, pushResult) + end +end + +local function searchByNode(source, pushResult) + local node = compiler.compileNode(source) + if not node then + return + end + for n in nodeMgr.eachNode(node) do + if n.type == 'global' and n.cate == 'type' then + for _, set in ipairs(n:getSets()) do + pushResult(set) + end + else + pushResult(n) + end + end +end + +---@param source parser.object +---@return parser.object[] +function vm.getDefs(source) + local results = {} + local mark = {} + + local hasLocal + local function pushResult(src) + if src.type == 'local' and not src.dummy then + if hasLocal then + return + end + hasLocal = true + end + if not mark[src] then + mark[src] = true + if src.type == 'global' then + for _, set in ipairs(src:getSets()) do + pushResult(set) + end + return + end + if guide.isSet(src) + or guide.isLiteral(src) then + results[#results+1] = src + end + end + end + + searchBySimple(source, pushResult) + searchByLocalID(source, pushResult) + searchByParentNode(source, pushResult) + searchByNode(source, pushResult) + + return results +end diff --git a/script/vm/getDef.lua b/script/vm/getDef.lua deleted file mode 100644 index 5edc6c67..00000000 --- a/script/vm/getDef.lua +++ /dev/null @@ -1,256 +0,0 @@ ----@class vm -local vm = require 'vm.vm' -local util = require 'utility' -local compiler = require 'vm.compiler' -local guide = require 'parser.guide' -local localID = require 'vm.local-id' -local globalMgr = require 'vm.global-manager' -local nodeMgr = require 'vm.node' - -local simpleMap - -local function searchGetLocal(source, node, pushResult) - local key = guide.getKeyName(source) - for _, ref in ipairs(node.node.ref) do - if ref.type == 'getlocal' - and guide.isSet(ref.next) - and guide.getKeyName(ref.next) == key then - pushResult(ref.next) - end - end -end - -simpleMap = util.switch() - : case 'local' - : call(function (source, pushResult) - pushResult(source) - if source.ref then - for _, ref in ipairs(source.ref) do - if ref.type == 'setlocal' then - pushResult(ref) - end - end - end - - if source.dummy then - for _, res in ipairs(vm.getDefs(source.method.node)) do - pushResult(res) - end - end - end) - : case 'getlocal' - : case 'setlocal' - : call(function (source, pushResult) - simpleMap['local'](source.node, pushResult) - end) - : case 'field' - : call(function (source, pushResult) - local parent = source.parent - simpleMap[parent.type](parent, pushResult) - end) - : case 'setfield' - : case 'getfield' - : call(function (source, pushResult) - local node = source.node - if node.type == 'getlocal' then - searchGetLocal(source, node, pushResult) - return - end - end) - : case 'getindex' - : case 'setindex' - : call(function (source, pushResult) - local node = source.node - if node.type == 'getlocal' then - searchGetLocal(source, node, pushResult) - end - end) - : case 'goto' - : call(function (source, pushResult) - if source.node then - pushResult(source.node) - end - end) - : getMap() - -local searchFieldMap = util.switch() - : case 'table' - : call(function (node, key, pushResult) - for _, field in ipairs(node) do - if field.type == 'tablefield' - or field.type == 'tableindex' then - if guide.getKeyName(field) == key then - pushResult(field) - end - end - end - end) - : case 'global' - ---@param node vm.node - ---@param key string - : call(function (node, key, pushResult) - if node.cate == 'variable' then - local newGlobal = globalMgr.getGlobal('variable', node.name, key) - if newGlobal then - for _, set in ipairs(newGlobal:getSets()) do - pushResult(set) - end - end - end - if node.cate == 'type' then - compiler.getClassFields(node, key, pushResult) - end - end) - : case 'local' - : call(function (node, key, pushResult) - local sources = localID.getSources(node, key) - if sources then - for _, src in ipairs(sources) do - if guide.isSet(src) then - pushResult(src) - end - end - end - end) - : case 'doc.type.table' - : call(function (node, key, pushResult) - for _, field in ipairs(node.fields) do - local fieldKey = field.name - if fieldKey.type == 'doc.field.name' then - if fieldKey[1] == key then - pushResult(field) - end - end - end - end) - : getMap() - -local searchByParentNode -local nodeMap = util.switch() - : case 'field' - : case 'method' - : call(function (source, pushResult) - searchByParentNode(source.parent, pushResult) - end) - : case 'getfield' - : case 'setfield' - : case 'getmethod' - : case 'setmethod' - : case 'getindex' - : case 'setindex' - : call(function (source, pushResult) - local parentNode = compiler.compileNode(source.node) - if not parentNode then - return - end - local key = guide.getKeyName(source) - for pn in nodeMgr.eachNode(parentNode) do - if searchFieldMap[pn.type] then - searchFieldMap[pn.type](pn, key, pushResult) - end - end - end) - : case 'doc.see.field' - : call(function (source, pushResult) - local parentNode = compiler.compileNode(source.parent.name) - if not parentNode then - return - end - for pn in nodeMgr.eachNode(parentNode) do - if searchFieldMap[pn.type] then - searchFieldMap[pn.type](pn, source[1], pushResult) - end - end - end) - : getMap() - - ---@param source parser.object - ---@param pushResult fun(src: parser.object) -local function searchBySimple(source, pushResult) - local simple = simpleMap[source.type] - if simple then - simple(source, pushResult) - end -end - ----@param source parser.object ----@param pushResult fun(src: parser.object) -local function searchByLocalID(source, pushResult) - local idSources = localID.getSources(source) - if not idSources then - return - end - for _, src in ipairs(idSources) do - if guide.isSet(src) then - pushResult(src) - end - end -end - ----@param source parser.object ----@param pushResult fun(src: parser.object) -function searchByParentNode(source, pushResult) - local node = nodeMap[source.type] - if node then - node(source, pushResult) - end -end - -local function searchByNode(source, pushResult) - local node = compiler.compileNode(source) - if not node then - return - end - for n in nodeMgr.eachNode(node) do - if n.type == 'global' and n.cate == 'type' then - for _, set in ipairs(n:getSets()) do - pushResult(set) - end - else - pushResult(n) - end - end -end - ----@param source parser.object ----@return parser.object[] -function vm.getDefs(source) - local results = {} - local mark = {} - - local hasLocal - local function pushResult(src) - if src.type == 'local' and not src.dummy then - if hasLocal then - return - end - hasLocal = true - end - if not mark[src] then - mark[src] = true - if src.type == 'global' then - for _, set in ipairs(src:getSets()) do - pushResult(set) - end - return - end - if guide.isSet(src) - or guide.isLiteral(src) then - results[#results+1] = src - end - end - end - - searchBySimple(source, pushResult) - searchByLocalID(source, pushResult) - searchByParentNode(source, pushResult) - searchByNode(source, pushResult) - - return results -end - ----@param source parser.object ----@return parser.object[] -function vm.getAllDefs(source) - return vm.getDefs(source) -end diff --git a/script/vm/infer.lua b/script/vm/infer.lua new file mode 100644 index 00000000..8b2cf5ee --- /dev/null +++ b/script/vm/infer.lua @@ -0,0 +1,52 @@ +local util = require 'utility' +local nodeMgr = require 'vm.node' + +---@class vm.infer-manager +local m = {} + +local viewNodeMap = util.switch() + : case 'boolean' + : case 'string' + : case 'table' + : case 'function' + : case 'number' + : case 'integer' + : call(function (source) + return source.type + end) + : getMap() + +---@param node vm.node +---@return string? +local function viewNode(node) + if viewNodeMap[node.type] then + return viewNodeMap[node.type](node) + end +end + +---@param source parser.object +function m.viewType(source) + local compiler = require 'vm.compiler' + local node = compiler.compileNode(source) + local views = {} + for n in nodeMgr.eachNode(node) do + local view = viewNode(n) + if view then + views[view] = true + end + end + if views['number'] then + views['integer'] = nil + end + local array = {} + for view in pairs(views) do + array[#array+1] = view + end + if #array == 0 then + return 'unknown' + end + table.sort(array) + return table.concat(array, '|') +end + +return m diff --git a/script/vm/init.lua b/script/vm/init.lua index 79cecbb7..2669c849 100644 --- a/script/vm/init.lua +++ b/script/vm/init.lua @@ -1,8 +1,8 @@ local vm = require 'vm.vm' require 'vm.manager' +require 'vm.def' require 'vm.getDocs' require 'vm.getLibrary' -require 'vm.getDef' require 'vm.getRef' require 'vm.getLinks' return vm diff --git a/script/vm/node.lua b/script/vm/node.lua index a8a900ec..f9dbdf3e 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -5,6 +5,8 @@ local union = require 'vm.union' ---@class vm.node-manager local m = {} +local DUMMY_FUNCTION = function () end + ---@type table m.nodeCache = {} @@ -38,6 +40,9 @@ end ---@return fun():vm.node function m.eachNode(node) + if not node then + return DUMMY_FUNCTION + end if node.type == 'union' then return node:eachNode() end -- cgit v1.2.3