summaryrefslogtreecommitdiff
path: root/script/vm/infer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'script/vm/infer.lua')
-rw-r--r--script/vm/infer.lua211
1 files changed, 161 insertions, 50 deletions
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index fabc9828..263b2500 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -8,10 +8,9 @@ local vm = require 'vm.vm'
---@field views table<string, boolean>
---@field cachedView? string
---@field node? vm.node
----@field uri? uri
+---@field _drop table
local mt = {}
mt.__index = mt
-mt._hasNumber = false
mt._hasTable = false
mt._hasClass = false
mt._hasFunctionDef = false
@@ -21,6 +20,8 @@ mt._isLocal = false
vm.NULL = setmetatable({}, mt)
+local LOCK = {}
+
local inferSorted = {
['boolean'] = - 100,
['string'] = - 99,
@@ -43,14 +44,13 @@ local viewNodeSwitch = util.switch()
end)
: case 'number'
: call(function (source, infer)
- infer._hasNumber = true
return source.type
end)
: case 'table'
- : call(function (source, infer)
+ : call(function (source, infer, uri)
if source.type == 'table' then
if #source == 1 and source[1].type == 'varargs' then
- local node = vm.getInfer(source[1]):view()
+ local node = vm.getInfer(source[1]):view(uri)
return ('%s[]'):format(node)
end
end
@@ -76,19 +76,18 @@ local viewNodeSwitch = util.switch()
: case 'global'
: call(function (source, infer)
if source.cate == 'type' then
- infer._hasClass = true
- if source.name == 'number' then
- infer._hasNumber = true
+ if not guide.isBasicType(source.name) then
+ infer._hasClass = true
end
return source.name
end
end)
: case 'doc.type.name'
- : call(function (source, infer)
+ : call(function (source, infer, uri)
if source.signs then
local buf = {}
for i, sign in ipairs(source.signs) do
- buf[i] = vm.getInfer(sign):view()
+ buf[i] = vm.getInfer(sign):view(uri)
end
return ('%s<%s>'):format(source[1], table.concat(buf, ', '))
else
@@ -96,34 +95,68 @@ local viewNodeSwitch = util.switch()
end
end)
: case 'generic'
- : call(function (source, infer)
- return vm.getInfer(source.proto):view()
+ : call(function (source, infer, uri)
+ return vm.getInfer(source.proto):view(uri)
end)
: case 'doc.generic.name'
: call(function (source, infer)
return ('<%s>'):format(source[1])
end)
: case 'doc.type.array'
- : call(function (source, infer)
+ : call(function (source, infer, uri)
infer._hasClass = true
- local view = vm.getInfer(source.node):view()
+ local view = vm.getInfer(source.node):view(uri)
if source.node.type == 'doc.type' then
view = '(' .. view .. ')'
end
return view .. '[]'
end)
: case 'doc.type.sign'
- : call(function (source, infer)
+ : call(function (source, infer, uri)
infer._hasClass = true
local buf = {}
for i, sign in ipairs(source.signs) do
- buf[i] = vm.getInfer(sign):view()
+ buf[i] = vm.getInfer(sign):view(uri)
+ end
+ if infer._drop then
+ local node = vm.compileNode(source)
+ for c in node:eachObject() do
+ if guide.isLiteral(c) then
+ infer._drop[c] = true
+ end
+ end
end
return ('%s<%s>'):format(source.node[1], table.concat(buf, ', '))
end)
: case 'doc.type.table'
- : call(function (source, infer)
- infer._hasTable = true
+ : call(function (source, infer, uri)
+ if #source.fields == 0 then
+ infer._hasTable = true
+ return
+ end
+ if infer._drop and infer._drop[source] then
+ infer._hasTable = true
+ return
+ end
+ infer._hasClass = true
+ local buf = {}
+ buf[#buf+1] = '{ '
+ for i, field in ipairs(source.fields) do
+ if i > 1 then
+ buf[#buf+1] = ', '
+ end
+ local key = field.name
+ if key.type == 'doc.type' then
+ buf[#buf+1] = ('[%s]: '):format(vm.getInfer(key):view(uri))
+ elseif type(key[1]) == 'string' then
+ buf[#buf+1] = key[1] .. ': '
+ else
+ buf[#buf+1] = ('[%q]: '):format(key[1])
+ end
+ buf[#buf+1] = vm.getInfer(field.extends):view(uri)
+ end
+ buf[#buf+1] = ' }'
+ return table.concat(buf)
end)
: case 'doc.type.string'
: call(function (source, infer)
@@ -134,8 +167,12 @@ local viewNodeSwitch = util.switch()
: call(function (source, infer)
return ('%q'):format(source[1])
end)
- : case 'doc.type.function'
+ : case 'doc.type.code'
: call(function (source, infer)
+ return ('`%s`'):format(source[1])
+ end)
+ : case 'doc.type.function'
+ : call(function (source, infer, uri)
infer._hasDocFunction = true
local args = {}
local rets = {}
@@ -148,31 +185,53 @@ local viewNodeSwitch = util.switch()
argNode = argNode:copy()
argNode:removeOptional()
end
- args[i] = string.format('%s%s: %s'
+ args[i] = string.format('%s%s%s%s'
, arg.name[1]
, isOptional and '?' or ''
- , vm.getInfer(argNode):view()
+ , arg.name[1] == '...' and '' or ': '
+ , vm.getInfer(argNode):view(uri)
)
end
if #args > 0 then
argView = table.concat(args, ', ')
end
+ local needReturnParen
for i, ret in ipairs(source.returns) do
- rets[i] = vm.getInfer(ret):view()
+ local retType = vm.getInfer(ret):view(uri)
+ if ret.name then
+ if ret.name[1] == '...' then
+ rets[i] = ('%s%s'):format(ret.name[1], retType)
+ else
+ needReturnParen = true
+ rets[i] = ('%s: %s'):format(ret.name[1], retType)
+ end
+ else
+ rets[i] = retType
+ end
end
if #rets > 0 then
- regView = ':' .. table.concat(rets, ', ')
+ if needReturnParen then
+ regView = (':(%s)'):format(table.concat(rets, ', '))
+ else
+ regView = (':%s'):format(table.concat(rets, ', '))
+ end
end
return ('fun(%s)%s'):format(argView, regView)
end)
----@param source parser.object | vm.node
+---@class vm.node
+---@field lastInfer? vm.infer
+
+---@param source vm.object | vm.node
---@return vm.infer
function vm.getInfer(source)
+ ---@type vm.node
local node
if source.type == 'vm.node' then
+ ---@cast source vm.node
node = source
else
+ ---@cast source vm.object
node = vm.compileNode(source)
end
if node.lastInfer then
@@ -180,7 +239,7 @@ function vm.getInfer(source)
end
local infer = setmetatable({
node = node,
- uri = source.type ~= 'vm.node' and guide.getUri(source),
+ _drop = {},
}, mt)
node.lastInfer = infer
@@ -188,9 +247,6 @@ function vm.getInfer(source)
end
function mt:_trim()
- if self._hasNumber then
- self.views['integer'] = nil
- end
if self._hasDocFunction then
if self._hasFunctionDef then
for view in pairs(self.views) do
@@ -205,6 +261,13 @@ function mt:_trim()
if self._hasTable and not self._hasClass then
self.views['table'] = true
end
+ if self.views['number'] then
+ self.views['integer'] = nil
+ end
+ if self.views['boolean'] then
+ self.views['true'] = nil
+ self.views['false'] = nil
+ end
end
---@param uri uri
@@ -214,46 +277,86 @@ function mt:_eraseAlias(uri)
local expandAlias = config.get(uri, 'Lua.hover.expandAlias')
for n in self.node:eachObject() do
if n.type == 'global' and n.cate == 'type' then
+ if LOCK[n.name] then
+ goto CONTINUE
+ end
+ LOCK[n.name] = true
for _, set in ipairs(n:getSets(uri)) do
if set.type == 'doc.alias' then
if expandAlias then
drop[n.name] = true
+ local newInfer = {}
+ for _, ext in ipairs(set.extends.types) do
+ viewNodeSwitch(ext.type, ext, newInfer, uri)
+ end
+ if newInfer._hasTable then
+ self.views['table'] = true
+ end
else
for _, ext in ipairs(set.extends.types) do
- local view = viewNodeSwitch(ext.type, ext, {})
+ local view = viewNodeSwitch(ext.type, ext, {}, uri)
if view and view ~= n.name then
drop[view] = true
end
end
end
end
+ if set.type == 'doc.class' then
+ if set.extends then
+ for _, ext in ipairs(set.extends) do
+ if ext.type == 'doc.extends.name' then
+ local view = ext[1]
+ drop[view] = true
+ end
+ end
+ end
+ end
end
+ LOCK[n.name] = nil
+ ::CONTINUE::
end
end
return drop
end
+---@param uri uri
---@param tp string
---@return boolean
-function mt:hasType(tp)
- self:_computeViews()
+function mt:hasType(uri, tp)
+ self:_computeViews(uri)
return self.views[tp] == true
end
+---@param uri uri
+function mt:hasUnknown(uri)
+ self:_computeViews(uri)
+ return not next(self.views)
+ or self.views['unknown'] == true
+end
+
+---@param uri uri
+function mt:hasAny(uri)
+ self:_computeViews(uri)
+ return self.views['any'] == true
+end
+
+---@param uri uri
---@return boolean
-function mt:hasClass()
- self:_computeViews()
+function mt:hasClass(uri)
+ self:_computeViews(uri)
return self._hasClass == true
end
+---@param uri uri
---@return boolean
-function mt:hasFunction()
- self:_computeViews()
+function mt:hasFunction(uri)
+ self:_computeViews(uri)
return self.views['function'] == true
or self._hasDocFunction == true
end
-function mt:_computeViews()
+---@param uri uri
+function mt:_computeViews(uri)
if self.views then
return
end
@@ -261,7 +364,7 @@ function mt:_computeViews()
self.views = {}
for n in self.node:eachObject() do
- local view = viewNodeSwitch(n.type, n, self)
+ local view = viewNodeSwitch(n.type, n, self, uri)
if view then
self.views[view] = true
end
@@ -270,11 +373,11 @@ function mt:_computeViews()
self:_trim()
end
+---@param uri uri
---@param default? string
----@param uri? uri
---@return string
-function mt:view(default, uri)
- self:_computeViews()
+function mt:view(uri, default)
+ self:_computeViews(uri)
if self.views['any'] then
return 'any'
@@ -282,7 +385,7 @@ function mt:view(default, uri)
local drop
if self._hasClass then
- drop = self:_eraseAlias(uri or self.uri)
+ drop = self:_eraseAlias(uri)
end
local array = {}
@@ -302,7 +405,7 @@ function mt:view(default, uri)
end)
local max = #array
- local limit = config.get(uri or self.uri, 'Lua.hover.enumsLimit')
+ local limit = config.get(uri, 'Lua.hover.enumsLimit')
local view
if #array == 0 then
@@ -329,8 +432,9 @@ function mt:view(default, uri)
return view
end
-function mt:eachView()
- self:_computeViews()
+---@param uri uri
+function mt:eachView(uri)
+ self:_computeViews(uri)
return next, self.views
end
@@ -346,7 +450,6 @@ function mt:merge(other)
local infer = setmetatable({
node = vm.createNode(self.node, other.node),
- uri = self.uri,
}, mt)
return infer
@@ -365,7 +468,7 @@ function mt:viewLiterals()
or n.type == 'integer'
or n.type == 'boolean' then
local literal = util.viewLiteral(n[1])
- if not mark[literal] then
+ if literal and not mark[literal] then
literals[#literals+1] = literal
mark[literal] = true
end
@@ -374,7 +477,14 @@ function mt:viewLiterals()
if #literals == 0 then
return nil
end
- table.sort(literals)
+ table.sort(literals, function (a, b)
+ local sa = inferSorted[a] or 0
+ local sb = inferSorted[b] or 0
+ if sa == sb then
+ return a < b
+ end
+ return sa < sb
+ end)
return table.concat(literals, '|')
end
@@ -401,8 +511,9 @@ function mt:viewClass()
return table.concat(class, '|')
end
----@param source parser.object
+---@param source vm.node.object
+---@param uri uri
---@return string?
-function vm.viewObject(source)
- return viewNodeSwitch(source.type, source, {})
+function vm.viewObject(source, uri)
+ return viewNodeSwitch(source.type, source, {}, uri)
end