From 916b8563cc327af32a5c3dccfdb5434711d83377 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=80=E8=90=8C=E5=B0=8F=E6=B1=90?= Date: Thu, 9 Jun 2022 16:32:05 +0800 Subject: view infer must specify uri --- script/core/completion/completion.lua | 12 ++--- script/core/diagnostics/close-non-object.lua | 8 ++-- script/core/diagnostics/no-unknown.lua | 2 +- script/core/diagnostics/not-yieldable.lua | 4 +- script/core/diagnostics/undefined-field.lua | 2 +- script/core/hint.lua | 2 +- script/core/hover/args.lua | 10 ++-- script/core/hover/description.lua | 2 +- script/core/hover/init.lua | 2 +- script/core/hover/label.lua | 6 +-- script/core/hover/return.lua | 5 +- script/core/hover/table.lua | 6 +-- script/core/semantic-tokens.lua | 6 +-- script/vm/infer.lua | 68 ++++++++++++++-------------- test/type_inference/init.lua | 4 +- 15 files changed, 71 insertions(+), 68 deletions(-) diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua index d2d0a040..009b4297 100644 --- a/script/core/completion/completion.lua +++ b/script/core/completion/completion.lua @@ -184,7 +184,7 @@ local function buildFunctionSnip(source, value, oop) end local function buildDetail(source) - local types = vm.getInfer(source):view() + local types = vm.getInfer(source):view(guide.getUri(source)) local literals = vm.getInfer(source):viewLiterals() if literals then return types .. ' = ' .. literals @@ -302,7 +302,7 @@ local function checkLocal(state, word, position, results) if name:sub(1, 1) == '@' then goto CONTINUE end - if vm.getInfer(source):hasFunction() then + if vm.getInfer(source):hasFunction(state.uri) then local defs = vm.getDefs(source) -- make sure `function` is before `doc.type.function` local orders = {} @@ -513,7 +513,7 @@ local function checkFieldThen(state, name, src, word, startPos, position, parent }) return end - if oop and not vm.getInfer(src):hasFunction() then + if oop and not vm.getInfer(src):hasFunction(state.uri) then return end local literal = guide.getLiteral(value) @@ -1440,7 +1440,7 @@ local function tryCallArg(state, position, results) : string() end enums[#enums+1] = { - label = vm.getInfer(src):view(), + label = vm.getInfer(src):view(state.uri), description = description, kind = define.CompletionItemKind.Function, insertText = insertText, @@ -1819,14 +1819,14 @@ local function buildluaDocOfFunction(func) local returns = {} if func.args then for _, arg in ipairs(func.args) do - args[#args+1] = vm.getInfer(arg):view() + args[#args+1] = vm.getInfer(arg):view(guide.getUri(func)) end end if func.returns then for _, rtns in ipairs(func.returns) do for n = 1, #rtns do if not returns[n] then - returns[n] = vm.getInfer(rtns[n]):view() + returns[n] = vm.getInfer(rtns[n]):view(guide.getUri(func)) end end end diff --git a/script/core/diagnostics/close-non-object.lua b/script/core/diagnostics/close-non-object.lua index c97014fa..d07aaebe 100644 --- a/script/core/diagnostics/close-non-object.lua +++ b/script/core/diagnostics/close-non-object.lua @@ -25,10 +25,10 @@ return function (uri, callback) return end local infer = vm.getInfer(source.value) - if not infer:hasClass() - and not infer:hasType 'nil' - and not infer:hasType 'table' - and infer:view('any', uri) ~= 'any' then + if not infer:hasClass(uri) + and not infer:hasType(uri, 'nil') + and not infer:hasType(uri, 'table') + and infer:view(uri, 'any') ~= 'any' then callback { start = source.value.start, finish = source.value.finish, diff --git a/script/core/diagnostics/no-unknown.lua b/script/core/diagnostics/no-unknown.lua index 48aab5da..ff9f7a83 100644 --- a/script/core/diagnostics/no-unknown.lua +++ b/script/core/diagnostics/no-unknown.lua @@ -20,7 +20,7 @@ return function (uri, callback) and source.type ~= 'tableindex' then return end - if vm.getInfer(source):view() == 'unknown' then + if vm.getInfer(source):view(uri) == 'unknown' then callback { start = source.start, finish = source.finish, diff --git a/script/core/diagnostics/not-yieldable.lua b/script/core/diagnostics/not-yieldable.lua index a1c84276..055025d4 100644 --- a/script/core/diagnostics/not-yieldable.lua +++ b/script/core/diagnostics/not-yieldable.lua @@ -11,7 +11,7 @@ local function isYieldAble(defs, i) local arg = def.args and def.args[i] if arg then hasFuncDef = true - if vm.getInfer(arg):hasType 'any' + if vm.getInfer(arg):hasType(guide.getUri(def), 'any') or vm.isAsync(arg, true) or arg.type == '...' then return true @@ -22,7 +22,7 @@ local function isYieldAble(defs, i) local arg = def.args and def.args[i] if arg then hasFuncDef = true - if vm.getInfer(arg.extends):hasType 'any' + if vm.getInfer(arg.extends):hasType(guide.getUri(def), 'any') or vm.isAsync(arg.extends, true) then return true end diff --git a/script/core/diagnostics/undefined-field.lua b/script/core/diagnostics/undefined-field.lua index 41fcda48..b03838fd 100644 --- a/script/core/diagnostics/undefined-field.lua +++ b/script/core/diagnostics/undefined-field.lua @@ -34,7 +34,7 @@ return function (uri, callback) local node = src.node if node then local ok - for view in vm.getInfer(node):eachView() do + for view in vm.getInfer(node):eachView(uri) do if not skipCheckClass[view] then ok = true break diff --git a/script/core/hint.lua b/script/core/hint.lua index f97cdcec..350dc114 100644 --- a/script/core/hint.lua +++ b/script/core/hint.lua @@ -38,7 +38,7 @@ local function typeHint(uri, results, start, finish) end end await.delay() - local view = vm.getInfer(source):view() + local view = vm.getInfer(source):view(uri) if view == 'any' or view == 'unknown' or view == 'nil' then diff --git a/script/core/hover/args.lua b/script/core/hover/args.lua index c485d9b9..21e7c00f 100644 --- a/script/core/hover/args.lua +++ b/script/core/hover/args.lua @@ -9,7 +9,7 @@ local function asFunction(source) methodDef = true end if methodDef then - args[#args+1] = ('self: %s'):format(vm.getInfer(parent.node):view 'any') + args[#args+1] = ('self: %s'):format(vm.getInfer(parent.node):view(guide.getUri(source), 'any')) end if source.args then for i = 1, #source.args do @@ -29,15 +29,15 @@ local function asFunction(source) args[#args+1] = ('%s%s: %s'):format( name, optional and '?' or '', - vm.getInfer(argNode):view('any', guide.getUri(source)) + vm.getInfer(argNode):view(guide.getUri(source), 'any') ) elseif arg.type == '...' then args[#args+1] = ('%s: %s'):format( '...', - vm.getInfer(arg):view 'any' + vm.getInfer(arg):view(guide.getUri(source), 'any') ) else - args[#args+1] = ('%s'):format(vm.getInfer(arg):view 'any') + args[#args+1] = ('%s'):format(vm.getInfer(arg):view(guide.getUri(source), 'any')) end ::CONTINUE:: end @@ -56,7 +56,7 @@ local function asDocFunction(source) args[i] = ('%s%s: %s'):format( name, arg.optional and '?' or '', - arg.extends and vm.getInfer(arg.extends):view 'any' or 'any' + arg.extends and vm.getInfer(arg.extends):view(guide.getUri(source), 'any') or 'any' ) end return args diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua index 712ea1ad..e3c3f412 100644 --- a/script/core/hover/description.lua +++ b/script/core/hover/description.lua @@ -152,7 +152,7 @@ local function buildEnumChunk(docType, name) local types = {} local lines = {} for _, tp in ipairs(vm.getDefs(docType)) do - types[#types+1] = vm.getInfer(tp):view() + types[#types+1] = vm.getInfer(tp):view(guide.getUri(docType)) if tp.type == 'doc.type.string' or tp.type == 'doc.type.integer' or tp.type == 'doc.type.boolean' then diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua index 7231944a..949156aa 100644 --- a/script/core/hover/init.lua +++ b/script/core/hover/init.lua @@ -39,7 +39,7 @@ local function getHover(source) end local oop - if vm.getInfer(source):view() == 'function' then + if vm.getInfer(source):view(guide.getUri(source)) == 'function' then local defs = vm.getDefs(source) -- make sure `function` is before `doc.type.function` local orders = {} diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua index 2bbfe806..e725f6b0 100644 --- a/script/core/hover/label.lua +++ b/script/core/hover/label.lua @@ -33,7 +33,7 @@ local function asDocTypeName(source) return '(class) ' .. doc.class[1] end if doc.type == 'doc.alias' then - return '(alias) ' .. doc.alias[1] .. ' ' .. lang.script('HOVER_EXTENDS', vm.getInfer(doc.extends):view()) + return '(alias) ' .. doc.alias[1] .. ' ' .. lang.script('HOVER_EXTENDS', vm.getInfer(doc.extends):view(guide.getUri(source))) end end end @@ -42,7 +42,7 @@ end local function asValue(source, title) local name = buildName(source, false) or '' local ifr = vm.getInfer(source) - local type = ifr:view() + local type = ifr:view(guide.getUri(source)) local literal = ifr:viewLiterals() local cont = buildTable(source) local pack = {} @@ -139,7 +139,7 @@ local function asDocFieldName(source) break end end - local view = vm.getInfer(source.extends):view() + local view = vm.getInfer(source.extends):view(guide.getUri(source)) if not class then return ('(field) ?.%s: %s'):format(name, view) end diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua index 3d8a94a5..2ee234b6 100644 --- a/script/core/hover/return.lua +++ b/script/core/hover/return.lua @@ -1,4 +1,5 @@ local vm = require 'vm.vm' +local guide = require 'parser.guide' ---@param source parser.object ---@return integer @@ -65,7 +66,7 @@ local function asFunction(source) local name = doc and doc.name and doc.name[1] and (doc.name[1] .. ': ') local text = ('%s%s'):format( name or '', - vm.getInfer(rtn):view() + vm.getInfer(rtn):view(guide.getUri(source)) ) if i == 1 then returns[i] = (' -> %s'):format(text) @@ -83,7 +84,7 @@ local function asDocFunction(source) end local returns = {} for i, rtn in ipairs(source.returns) do - local rtnText = vm.getInfer(rtn):view() + local rtnText = vm.getInfer(rtn):view(guide.getUri(source)) if i == 1 then returns[#returns+1] = (' -> %s'):format(rtnText) else diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua index 16874101..97ace24e 100644 --- a/script/core/hover/table.lua +++ b/script/core/hover/table.lua @@ -30,7 +30,7 @@ local function buildAsHash(uri, keys, nodeMap, reachMax) node:removeOptional() end local ifr = vm.getInfer(node) - local typeView = ifr:view('unknown', uri) + local typeView = ifr:view(uri, 'unknown') local literalView = ifr:viewLiterals() if literalView then lines[#lines+1] = (' %s%s: %s = %s,'):format( @@ -75,7 +75,7 @@ local function buildAsConst(uri, keys, nodeMap, reachMax) node = node:copy() node:removeOptional() end - local typeView = vm.getInfer(node):view('unknown', uri) + local typeView = vm.getInfer(node):view(uri, 'unknown') local literalView = literalMap[key] if literalView then lines[#lines+1] = (' %s%s: %s = %s,'):format( @@ -178,7 +178,7 @@ return function (source) return nil end - for view in vm.getInfer(source):eachView() do + for view in vm.getInfer(source):eachView(uri) do if view == 'string' or vm.isSubType(uri, view, 'string') then return nil diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua index 33449013..b16f55fd 100644 --- a/script/core/semantic-tokens.lua +++ b/script/core/semantic-tokens.lua @@ -32,7 +32,7 @@ local Care = util.switch() end options.libGlobals[name] = isLib end - local isFunc = vm.getInfer(source):hasFunction() + local isFunc = vm.getInfer(source):hasFunction(guide.getUri(source)) local type = isFunc and define.TokenTypes['function'] or define.TokenTypes.variable local modifier = isLib and define.TokenModifiers.defaultLibrary or define.TokenModifiers.static @@ -81,7 +81,7 @@ local Care = util.switch() return end end - if vm.getInfer(source):hasFunction() then + if vm.getInfer(source):hasFunction(guide.getUri(source)) then results[#results+1] = { start = source.start, finish = source.finish, @@ -196,7 +196,7 @@ local Care = util.switch() end end -- 6. References to other functions - if vm.getInfer(loc):hasFunction() then + if vm.getInfer(loc):hasFunction(guide.getUri(source)) then results[#results+1] = { start = source.start, finish = source.finish, diff --git a/script/vm/infer.lua b/script/vm/infer.lua index 30682f37..ef8c6f29 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -8,7 +8,6 @@ local vm = require 'vm.vm' ---@field views table ---@field cachedView? string ---@field node? vm.node ----@field uri? uri local mt = {} mt.__index = mt mt._hasNumber = false @@ -47,10 +46,10 @@ local viewNodeSwitch = util.switch() return source.type end) : case 'table' - : call(function (source, infer) + : call(function (source, infer, uri) if source.type == 'table' then if #source == 1 and source[1].type == 'varargs' then - local node = vm.getInfer(source[1]):view() + local node = vm.getInfer(source[1]):view(uri) return ('%s[]'):format(node) end end @@ -86,11 +85,11 @@ local viewNodeSwitch = util.switch() end end) : case 'doc.type.name' - : call(function (source, infer) + : call(function (source, infer, uri) if source.signs then local buf = {} for i, sign in ipairs(source.signs) do - buf[i] = vm.getInfer(sign):view() + buf[i] = vm.getInfer(sign):view(uri) end return ('%s<%s>'):format(source[1], table.concat(buf, ', ')) else @@ -98,28 +97,28 @@ local viewNodeSwitch = util.switch() end end) : case 'generic' - : call(function (source, infer) - return vm.getInfer(source.proto):view() + : call(function (source, infer, uri) + return vm.getInfer(source.proto):view(uri) end) : case 'doc.generic.name' : call(function (source, infer) return ('<%s>'):format(source[1]) end) : case 'doc.type.array' - : call(function (source, infer) + : call(function (source, infer, uri) infer._hasClass = true - local view = vm.getInfer(source.node):view() + local view = vm.getInfer(source.node):view(uri) if source.node.type == 'doc.type' then view = '(' .. view .. ')' end return view .. '[]' end) : case 'doc.type.sign' - : call(function (source, infer) + : call(function (source, infer, uri) infer._hasClass = true local buf = {} for i, sign in ipairs(source.signs) do - buf[i] = vm.getInfer(sign):view() + buf[i] = vm.getInfer(sign):view(uri) end return ('%s<%s>'):format(source.node[1], table.concat(buf, ', ')) end) @@ -137,7 +136,7 @@ local viewNodeSwitch = util.switch() return ('%q'):format(source[1]) end) : case 'doc.type.function' - : call(function (source, infer) + : call(function (source, infer, uri) infer._hasDocFunction = true local args = {} local rets = {} @@ -153,14 +152,14 @@ local viewNodeSwitch = util.switch() args[i] = string.format('%s%s: %s' , arg.name[1] , isOptional and '?' or '' - , vm.getInfer(argNode):view() + , vm.getInfer(argNode):view(uri) ) end if #args > 0 then argView = table.concat(args, ', ') end for i, ret in ipairs(source.returns) do - rets[i] = vm.getInfer(ret):view() + rets[i] = vm.getInfer(ret):view(uri) end if #rets > 0 then regView = ':' .. table.concat(rets, ', ') @@ -182,7 +181,6 @@ function vm.getInfer(source) end local infer = setmetatable({ node = node, - uri = source.type ~= 'vm.node' and guide.getUri(source), }, mt) node.lastInfer = infer @@ -222,14 +220,14 @@ function mt:_eraseAlias(uri) drop[n.name] = true local newInfer = {} for _, ext in ipairs(set.extends.types) do - viewNodeSwitch(ext.type, ext, newInfer) + viewNodeSwitch(ext.type, ext, newInfer, uri) end if newInfer._hasTable then self.views['table'] = true end else for _, ext in ipairs(set.extends.types) do - local view = viewNodeSwitch(ext.type, ext, {}) + local view = viewNodeSwitch(ext.type, ext, {}, uri) if view and view ~= n.name then drop[view] = true end @@ -242,27 +240,31 @@ function mt:_eraseAlias(uri) return drop end +---@param uri uri ---@param tp string ---@return boolean -function mt:hasType(tp) - self:_computeViews() +function mt:hasType(uri, tp) + self:_computeViews(uri) return self.views[tp] == true end +---@param uri uri ---@return boolean -function mt:hasClass() - self:_computeViews() +function mt:hasClass(uri) + self:_computeViews(uri) return self._hasClass == true end +---@param uri uri ---@return boolean -function mt:hasFunction() - self:_computeViews() +function mt:hasFunction(uri) + self:_computeViews(uri) return self.views['function'] == true or self._hasDocFunction == true end -function mt:_computeViews() +---@param uri uri +function mt:_computeViews(uri) if self.views then return end @@ -270,7 +272,7 @@ function mt:_computeViews() self.views = {} for n in self.node:eachObject() do - local view = viewNodeSwitch(n.type, n, self) + local view = viewNodeSwitch(n.type, n, self, uri) if view then self.views[view] = true end @@ -279,11 +281,11 @@ function mt:_computeViews() self:_trim() end +---@param uri uri ---@param default? string ----@param uri? uri ---@return string -function mt:view(default, uri) - self:_computeViews() +function mt:view(uri, default) + self:_computeViews(uri) if self.views['any'] then return 'any' @@ -291,7 +293,7 @@ function mt:view(default, uri) local drop if self._hasClass then - drop = self:_eraseAlias(uri or self.uri) + drop = self:_eraseAlias(uri) end local array = {} @@ -311,7 +313,7 @@ function mt:view(default, uri) end) local max = #array - local limit = config.get(uri or self.uri, 'Lua.hover.enumsLimit') + local limit = config.get(uri, 'Lua.hover.enumsLimit') local view if #array == 0 then @@ -338,8 +340,9 @@ function mt:view(default, uri) return view end -function mt:eachView() - self:_computeViews() +---@param uri uri +function mt:eachView(uri) + self:_computeViews(uri) return next, self.views end @@ -355,7 +358,6 @@ function mt:merge(other) local infer = setmetatable({ node = vm.createNode(self.node, other.node), - uri = self.uri, }, mt) return infer diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 41421ea9..56f57100 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -31,9 +31,9 @@ function TEST(wanted) files.setText('', newScript) local source = getSource(catched['?'][1][1]) assert(source) - local result = vm.getInfer(source):view() + local result = vm.getInfer(source):view('') if wanted ~= result then - vm.getInfer(source):view() + vm.getInfer(source):view('') end assert(wanted == result) files.remove('') -- cgit v1.2.3