summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/completion/completion.lua6
-rw-r--r--script/core/diagnostics/no-implicit-any.lua2
-rw-r--r--script/core/hint.lua3
-rw-r--r--script/core/hover/arg.lua10
-rw-r--r--script/core/hover/init.lua2
-rw-r--r--script/core/hover/label.lua10
-rw-r--r--script/core/hover/return.lua4
-rw-r--r--script/vm/field.lua5
-rw-r--r--script/vm/infer.lua161
-rw-r--r--script/vm/union.lua2
-rw-r--r--test/type_inference/init.lua4
11 files changed, 112 insertions, 97 deletions
diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua
index f54bd2d1..133b262d 100644
--- a/script/core/completion/completion.lua
+++ b/script/core/completion/completion.lua
@@ -167,7 +167,7 @@ local function buildDetail(source)
if source.type == 'dummy' then
return
end
- local types = infer.viewType(source)
+ local types = infer.getInfer(source):view()
local literals = infer.viewLiterals(source)
if literals then
return types .. ' = ' .. literals
@@ -1819,14 +1819,14 @@ local function buildluaDocOfFunction(func)
local returns = {}
if func.args then
for _, arg in ipairs(func.args) do
- args[#args+1] = infer.viewType(arg)
+ args[#args+1] = infer.getInfer(arg):view()
end
end
if func.returns then
for _, rtns in ipairs(func.returns) do
for n = 1, #rtns do
if not returns[n] then
- returns[n] = infer.viewType(rtns[n])
+ returns[n] = infer.getInfer(rtns[n]):view()
end
end
end
diff --git a/script/core/diagnostics/no-implicit-any.lua b/script/core/diagnostics/no-implicit-any.lua
index 47f1b997..5c14d211 100644
--- a/script/core/diagnostics/no-implicit-any.lua
+++ b/script/core/diagnostics/no-implicit-any.lua
@@ -20,7 +20,7 @@ return function (uri, callback)
and source.type ~= 'tableindex' then
return
end
- if infer.viewType(source) == 'any' then
+ if infer.getInfer(source):view() == 'unknown' then
callback {
start = source.start,
finish = source.finish,
diff --git a/script/core/hint.lua b/script/core/hint.lua
index 3b5db3e5..51842126 100644
--- a/script/core/hint.lua
+++ b/script/core/hint.lua
@@ -41,8 +41,9 @@ local function typeHint(uri, results, start, finish)
end
end
await.delay()
- local view = infer.viewType(source)
+ local view = infer.getInfer(source):view()
if view == 'any'
+ or view == 'unknown'
or view == 'nil' then
return
end
diff --git a/script/core/hover/arg.lua b/script/core/hover/arg.lua
index c9c81a85..7611a895 100644
--- a/script/core/hover/arg.lua
+++ b/script/core/hover/arg.lua
@@ -21,7 +21,7 @@ local function asFunction(source, oop)
methodDef = true
end
if methodDef then
- args[#args+1] = ('self: %s'):format(infer.viewType(parent.node))
+ args[#args+1] = ('self: %s'):format(infer.getInfer(parent.node))
end
if source.args then
for i = 1, #source.args do
@@ -34,15 +34,15 @@ local function asFunction(source, oop)
args[#args+1] = ('%s%s: %s'):format(
name,
optionalArg(arg) and '?' or '',
- infer.viewType(arg, 'any')
+ infer.getInfer(arg):view 'any'
)
elseif arg.type == '...' then
args[#args+1] = ('%s: %s'):format(
'...',
- infer.viewType(arg, 'any')
+ infer.getInfer(arg):view 'any'
)
else
- args[#args+1] = ('%s'):format(infer.viewType(arg, 'any'))
+ args[#args+1] = ('%s'):format(infer.getInfer(arg):view 'any')
end
::CONTINUE::
end
@@ -65,7 +65,7 @@ local function asDocFunction(source, oop)
args[i] = ('%s%s: %s'):format(
name,
arg.optional and '?' or '',
- arg.extends and infer.viewType(arg.extends) or 'any'
+ arg.extends and infer.getInfer(arg.extends):view 'any' or 'any'
)
end
if oop then
diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua
index fdfbd73d..bc2f40eb 100644
--- a/script/core/hover/init.lua
+++ b/script/core/hover/init.lua
@@ -40,7 +40,7 @@ local function getHover(source)
end
local oop
- if infer.viewType(source) == 'function' then
+ if infer.getInfer(source):view() == 'function' then
local hasFunc
for _, def in ipairs(vm.getDefs(source)) do
if guide.isOOP(def) then
diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua
index 01dd1143..150f0f24 100644
--- a/script/core/hover/label.lua
+++ b/script/core/hover/label.lua
@@ -35,7 +35,7 @@ local function asDocTypeName(source)
end
if doc.type == 'doc.alias.name' then
local extends = doc.parent.extends
- return lang.script('HOVER_EXTENDS', infer.viewType(extends))
+ return lang.script('HOVER_EXTENDS', infer.getInfer(extends):view())
end
end
end
@@ -43,12 +43,12 @@ end
---@async
local function asValue(source, title)
local name = buildName(source, false) or ''
- local type = infer.viewType(source)
+ local type = infer.getInfer(source):view()
local literal = infer.viewLiterals(source)
local cont
- if not infer.hasType(source, 'string')
+ if not infer.getInfer(source):hasType 'string'
and not type:find('%[%]$') then
- if infer.hasType(source, 'table') then
+ if infer.getInfer(source):hasType 'table' then
cont = buildTable(source)
end
end
@@ -131,7 +131,7 @@ local function asDocFieldName(source)
break
end
end
- local view = infer.viewType(docField.extends)
+ local view = infer.getInfer(docField.extends):view()
if not class then
return ('field ?.%s: %s'):format(name, view)
end
diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua
index e48febf3..cb8fa76f 100644
--- a/script/core/hover/return.lua
+++ b/script/core/hover/return.lua
@@ -67,7 +67,7 @@ local function asFunction(source)
local name = doc and doc.name and doc.name[1] and (doc.name[1] .. ': ')
local text = ('%s%s%s'):format(
name or '',
- infer.viewType(rtn),
+ infer.getInfer(rtn):view(),
doc and doc.optional and '?' or ''
)
if i == 1 then
@@ -87,7 +87,7 @@ local function asDocFunction(source)
local returns = {}
for i, rtn in ipairs(source.returns) do
local rtnText = ('%s%s'):format(
- infer.viewType(rtn),
+ infer.getInfer(rtn):view(),
rtn.optional and '?' or ''
)
if i == 1 then
diff --git a/script/vm/field.lua b/script/vm/field.lua
index 92448bb3..c30e112d 100644
--- a/script/vm/field.lua
+++ b/script/vm/field.lua
@@ -23,7 +23,10 @@ local function searchByNode(source, pushResult)
if not node then
return
end
- searchNodeSwitch(node.type, node, pushResult)
+
+ for n in nodeMgr.eachNode(node) do
+ searchNodeSwitch(n.type, n, pushResult)
+ end
end
---@param source parser.object
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index 6457696a..31a08b74 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -7,6 +7,22 @@ local compiler = require 'vm.compiler'
---@class vm.infer-manager
local m = {}
+---@class vm.infer
+---@field views table<string, boolean>
+---@field source? parser.object
+---@field cachedView? string
+local mt = {}
+mt.__index = mt
+mt.hasNumber = false
+mt.hasTable = false
+mt.hasClass = false
+mt.isParam = false
+mt.isLocal = false
+mt.hasDocFunction = false
+mt.expandAlias = false
+
+local nullInfer = setmetatable({ views = {} }, mt)
+
local inferSorted = {
['boolean'] = - 100,
['string'] = - 99,
@@ -23,44 +39,44 @@ local viewNodeSwitch = util.switch()
: case 'string'
: case 'function'
: case 'integer'
- : call(function (source, options)
+ : call(function (source, infer)
return source.type
end)
: case 'number'
- : call(function (source, options)
- options['hasNumber'] = true
+ : call(function (source, infer)
+ infer.hasNumber = true
return source.type
end)
: case 'table'
- : call(function (source, options)
- options['hasTable'] = true
+ : call(function (source, infer)
+ infer.hasTable = true
end)
: case 'local'
- : call(function (source, options)
+ : call(function (source, infer)
if source.parent == 'funcargs' then
- options['isParam'] = true
+ infer.isParam = true
else
- options['isLocal'] = true
+ infer.isLocal = true
end
end)
: case 'global'
- : call(function (source, options)
+ : call(function (source, infer)
if source.cate == 'type' then
- options['hasClass'] = true
+ infer.hasClass = true
return source.name
end
end)
: case 'doc.type.integer'
- : call(function (source, options)
+ : call(function (source, infer)
return ('%d'):format(source[1])
end)
: case 'doc.type.name'
- : call(function (source, options)
- options['hasClass'] = true
+ : call(function (source, infer)
+ infer.hasClass = true
if source.signs then
local buf = {}
for i, sign in ipairs(source.signs) do
- buf[i] = m.viewType(sign)
+ buf[i] = m.getInfer(sign):view()
end
return ('%s<%s>'):format(source[1], table.concat(buf, ', '))
else
@@ -68,25 +84,25 @@ local viewNodeSwitch = util.switch()
end
end)
: case 'doc.generic.name'
- : call(function (source, options)
+ : call(function (source, infer)
return ('<%s>'):format(source[1])
end)
: case 'doc.type.array'
- : call(function (source, options)
- options['hasClass'] = true
- return m.viewType(source.node) .. '[]'
+ : call(function (source, infer)
+ infer.hasClass = true
+ return m.getInfer(source.node):view() .. '[]'
end)
: case 'doc.type.table'
- : call(function (source, options)
- options['hasTable'] = true
+ : call(function (source, infer)
+ infer.hasTable = true
end)
: case 'doc.type.string'
- : call(function (source, options)
+ : call(function (source, infer)
return ('%q'):format(source[1])
end)
: case 'doc.type.function'
- : call(function (source, options)
- options['hasDocFunction'] = true
+ : call(function (source, infer)
+ infer.hasDocFunction = true
local args = {}
local rets = {}
local argView = ''
@@ -95,7 +111,7 @@ local viewNodeSwitch = util.switch()
args[i] = string.format('%s%s: %s'
, arg.name[1]
, arg.optional and '?' or ''
- , m.viewType(arg)
+ , m.getInfer(arg):view()
)
end
if #args > 0 then
@@ -103,7 +119,7 @@ local viewNodeSwitch = util.switch()
end
for i, ret in ipairs(source.returns) do
rets[i] = string.format('%s%s'
- , m.viewType(ret)
+ , m.getInfer(ret):view()
, ret.optional and '?' or ''
)
end
@@ -113,24 +129,20 @@ local viewNodeSwitch = util.switch()
return ('fun(%s)%s'):format(argView, regView)
end)
----@param node vm.node
----@return string?
-local function viewNode(node, options)
- return viewNodeSwitch(node.type, node, options)
-end
-
-local function eraseAlias(node, viewMap, options)
+---@param infer vm.infer
+local function eraseAlias(infer)
+ local node = compiler.compileNode(infer.source)
for n in nodeMgr.eachNode(node) do
if n.type == 'global' and n.cate == 'type' then
for _, set in ipairs(n:getSets()) do
if set.type == 'doc.alias' then
- if options['expandAlias'] then
- viewMap[n.name] = nil
+ if infer.expandAlias then
+ infer.views[n.name] = nil
else
for _, ext in ipairs(set.extends.types) do
- local view = viewNode(ext, {})
+ local view = viewNodeSwitch(ext.type, ext, {})
if view and view ~= n.name then
- viewMap[view] = nil
+ infer.views[view] = nil
end
end
end
@@ -141,71 +153,67 @@ local function eraseAlias(node, viewMap, options)
end
---@param source parser.object
----@return table<string, boolean>
----@return table<string, boolean>
-function m.getViews(source)
+---@return vm.infer
+function m.getInfer(source)
local node = compiler.compileNode(source)
if not node then
- return {}
+ return nullInfer
end
- if node.type == 'union' and node.lastViews then
- return node.lastViews
+ if node.type == 'union' and node.lastInfer then
+ return node.lastInfer
end
- local views = {}
- local options = {}
- options['expandAlias'] = config.get(guide.getUri(source), 'Lua.hover.expandAlias')
+ local infer = setmetatable({
+ source = source,
+ views = {}
+ }, mt)
+ infer.expandAlias = config.get(guide.getUri(source), 'Lua.hover.expandAlias')
if node.type == 'union' then
- node.lastViews = views
+ node.lastInfer = infer
end
for n in nodeMgr.eachNode(node) do
- local view = viewNode(n, options)
+ local view = viewNodeSwitch(n.type, n, infer)
if view then
- views[view] = true
+ infer.views[view] = true
end
end
- if options['hasNumber'] then
- views['integer'] = nil
+ if infer.hasNumber then
+ infer.views['integer'] = nil
end
- if options['hasDocFunction'] then
- views['function'] = nil
+ if infer.hasDocFunction then
+ infer.views['function'] = nil
end
- if options['hasTable'] and not options['hasClass'] then
- views['table'] = true
+ if infer.hasTable and not infer.hasClass then
+ infer.views['table'] = true
end
- if options['hasClass'] then
- eraseAlias(node, views, options)
+ if infer.hasClass then
+ eraseAlias(infer)
end
- return views, options
+ return infer
end
----@param source parser.object
---@param tp string
---@return boolean
-function m.hasType(source, tp)
- local views = m.getViews(source)
-
- if views[tp] then
- return true
- end
-
- return false
+function mt:hasType(tp)
+ return self.views[tp] == true
end
----@param source parser.object
+---@param default? string
---@return string
-function m.viewType(source, default)
- local views = m.getViews(source)
-
- if views['any'] then
+function mt:view(default)
+ if self.views['any'] then
return 'any'
end
- if not next(views) then
+ if not next(self.views) then
return default or 'unknown'
end
+ if self.cachedView then
+ return self.cachedView
+ end
+
local array = {}
- for view in pairs(views) do
+ for view in pairs(self.views) do
array[#array+1] = view
end
@@ -219,7 +227,7 @@ function m.viewType(source, default)
end)
local max = #array
- local limit = config.get(guide.getUri(source), 'Lua.hover.enumsLimit')
+ local limit = config.get(guide.getUri(self.source), 'Lua.hover.enumsLimit')
if max > limit then
local view = string.format('%s...(+%d)'
@@ -227,13 +235,16 @@ function m.viewType(source, default)
, max - limit
)
+ self.cachedView = view
+
return view
else
local view = table.concat(array, '|')
+ self.cachedView = view
+
return view
end
-
end
---@param source parser.object
diff --git a/script/vm/union.lua b/script/vm/union.lua
index 183f3440..10aebda1 100644
--- a/script/vm/union.lua
+++ b/script/vm/union.lua
@@ -5,7 +5,7 @@ local mt = {}
mt.__index = mt
mt.type = 'union'
mt.optional = nil
-mt.lastViews = nil
+mt.lastInfer = nil
---@param me parser.object
---@param node vm.node
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index e687afe1..1a2c4593 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -31,9 +31,9 @@ function TEST(wanted)
files.setText('', newScript)
local source = getSource(catched['?'][1][1])
assert(source)
- local result = infer.viewType(source)
+ local result = infer.getInfer(source):view()
if wanted ~= result then
- infer.viewType(source)
+ infer.getInfer(source):view()
end
assert(wanted == result)
files.remove('')