summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/collector.lua7
-rw-r--r--script/core/diagnostics/duplicate-doc-field.lua8
-rw-r--r--script/version.lua2
-rw-r--r--script/vm/compiler.lua23
-rw-r--r--script/vm/node.lua11
-rw-r--r--script/vm/type.lua5
-rw-r--r--script/vm/value.lua2
-rw-r--r--test/diagnostics/common.lua16
-rw-r--r--test/type_inference/init.lua27
9 files changed, 93 insertions, 8 deletions
diff --git a/script/core/collector.lua b/script/core/collector.lua
index a2e3ca08..368a04ec 100644
--- a/script/core/collector.lua
+++ b/script/core/collector.lua
@@ -71,6 +71,8 @@ local DUMMY_FUNCTION = function () end
local function eachOfFolder(nameCollect, scp)
local curi, value
+ ---@return any
+ ---@return uri
local function getNext()
curi, value = next(nameCollect, curi)
if not curi then
@@ -90,6 +92,8 @@ end
local function eachOfLinked(nameCollect, scp)
local curi, value
+ ---@return any
+ ---@return uri
local function getNext()
curi, value = next(nameCollect, curi)
if not curi then
@@ -120,6 +124,8 @@ end
local function eachOfFallback(nameCollect, scp)
local curi, value
+ ---@return any
+ ---@return uri
local function getNext()
curi, value = next(nameCollect, curi)
if not curi then
@@ -146,6 +152,7 @@ end
--- 迭代某个名字的订阅
---@param uri uri
---@param name string
+---@return fun():any, uri
function mt:each(uri, name)
uri = uri or '<fallback>'
local nameCollect = self.collect[name]
diff --git a/script/core/diagnostics/duplicate-doc-field.lua b/script/core/diagnostics/duplicate-doc-field.lua
index d4116b9b..78112beb 100644
--- a/script/core/diagnostics/duplicate-doc-field.lua
+++ b/script/core/diagnostics/duplicate-doc-field.lua
@@ -1,5 +1,6 @@
local files = require 'files'
local lang = require 'language'
+local vm = require 'vm.vm'
local function getFieldEventName(doc)
if not doc.extends then
@@ -45,7 +46,12 @@ return function (uri, callback)
mark = {}
elseif doc.type == 'doc.field' then
if mark then
- local name = ('%q'):format(doc.field[1])
+ local name
+ if doc.field.type == 'doc.type' then
+ name = ('[%s]'):format(vm.getInfer(doc.field):view(uri))
+ else
+ name = ('%q'):format(doc.field[1])
+ end
local eventName = getFieldEventName(doc)
if eventName then
name = name .. '|' .. eventName
diff --git a/script/version.lua b/script/version.lua
index fa178564..f5cf5304 100644
--- a/script/version.lua
+++ b/script/version.lua
@@ -1,7 +1,7 @@
local fsu = require 'fs-utility'
local function loadVersion()
- local changelog = fsu.loadFile(ROOT / 'changelog.md')
+ local changelog = fsu.loadFile(ROOT / 'changelog.md'--[[@as fspath]])
if not changelog then
return
end
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 27ba6273..70a4ea92 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -253,7 +253,11 @@ local searchFieldSwitch = util.switch()
end
end)
-
+---@param suri uri
+---@param object vm.global
+---@param key string|vm.global
+---@param ref boolean
+---@param pushResult fun(field: vm.object, isMark?: boolean)
function vm.getClassFields(suri, object, key, ref, pushResult)
local mark = {}
@@ -376,7 +380,7 @@ function vm.getClassFields(suri, object, key, ref, pushResult)
for _, set in ipairs(sets) do
pushResult(set)
end
- else
+ elseif type(key) == 'string' then
local global = vm.getGlobal('variable', key)
if global then
for _, set in ipairs(global:getSets(suri)) do
@@ -677,7 +681,7 @@ local function bindAs(source)
end
---@param source vm.node
----@param key? any
+---@param key? string|vm.global
---@param pushResult fun(source: parser.object)
function vm.compileByParentNode(source, key, ref, pushResult)
local parentNode = vm.compileNode(source)
@@ -1380,12 +1384,25 @@ local compilerSwitch = util.switch()
return
end
if type(key) == 'table' then
+ ---@cast key vm.node
local uri = guide.getUri(source)
local value = vm.getTableValue(uri, vm.compileNode(source.node), key)
if value then
vm.setNode(source, value):removeOptional()
end
+ for k in key:eachObject() do
+ if k.type == 'global' and k.cate == 'type' then
+ ---@cast k vm.global
+ vm.compileByParentNode(source.node, k, false, function (src)
+ vm.setNode(source, vm.compileNode(src))
+ if src.value then
+ vm.setNode(source, vm.compileNode(src.value)):removeOptional()
+ end
+ end)
+ end
+ end
else
+ ---@cast key string
vm.compileByParentNode(source.node, key, false, function (src)
vm.setNode(source, vm.compileNode(src))
if src.value then
diff --git a/script/vm/node.lua b/script/vm/node.lua
index 65a203f8..24f87a45 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -9,6 +9,7 @@ vm.nodeCache = {}
---@class vm.node
---@field [integer] vm.object
+---@field [vm.object] true
local mt = {}
mt.__index = mt
mt.id = 0
@@ -234,8 +235,7 @@ function mt:narrow(name)
end
for index = #self, 1, -1 do
local c = self[index]
- if (c.type == 'global' and c.cate == 'type' and c.name == name)
- or (c.type == name)
+ if (c.type == name)
or (c.type == 'doc.type.integer' and (name == 'number' or name == 'integer'))
or (c.type == 'doc.type.boolean' and name == 'boolean')
or (c.type == 'doc.type.table' and name == 'table')
@@ -243,6 +243,12 @@ function mt:narrow(name)
or (c.type == 'doc.type.function' and name == 'function') then
goto CONTINUE
end
+ if c.type == 'global' and c.cate == 'type' then
+ if (c.name == name)
+ or (c.name == 'integer' and name == 'number') then
+ goto CONTINUE
+ end
+ end
table.remove(self, index)
self[c] = nil
::CONTINUE::
@@ -268,6 +274,7 @@ end
function mt:removeNode(node)
for _, c in ipairs(node) do
if c.type == 'global' and c.cate == 'type' then
+ ---@cast c vm.global
self:remove(c.name)
elseif c.type == 'nil' then
self:remove 'nil'
diff --git a/script/vm/type.lua b/script/vm/type.lua
index 05ada9ea..982f937d 100644
--- a/script/vm/type.lua
+++ b/script/vm/type.lua
@@ -84,6 +84,9 @@ function vm.isSubType(uri, child, parent, mark)
return false
end
+ ---@cast child vm.object
+ ---@cast parent vm.object
+
local childName = getNodeName(child)
local parentName = getNodeName(parent)
if childName == 'any'
@@ -146,7 +149,7 @@ end
---@param uri uri
---@param tnode vm.node
----@param knode vm.node
+---@param knode vm.node|string
---@return vm.node?
function vm.getTableValue(uri, tnode, knode)
local result = vm.createNode()
diff --git a/script/vm/value.lua b/script/vm/value.lua
index 13e27e59..e6e2045d 100644
--- a/script/vm/value.lua
+++ b/script/vm/value.lua
@@ -72,9 +72,11 @@ local function getUnique(v)
return ('num:%s'):format(v[1])
end
if v.type == 'table' then
+ ---@cast v parser.object
return ('table:%s@%d'):format(guide.getUri(v), v.start)
end
if v.type == 'function' then
+ ---@cast v parser.object
return ('func:%s@%d'):format(guide.getUri(v), v.start)
end
return false
diff --git a/test/diagnostics/common.lua b/test/diagnostics/common.lua
index 0708f63a..a837fc0f 100644
--- a/test/diagnostics/common.lua
+++ b/test/diagnostics/common.lua
@@ -1659,3 +1659,19 @@ k(f())
TEST [[
---@cast <!x!> integer
]]
+
+TEST [[
+---@class A
+
+---@class B
+---@field [integer] A
+---@field [A] true
+]]
+
+TEST [[
+---@class A
+
+---@class B
+---@field [A] A
+---@field [<!A!>] true
+]]
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index bca3a546..434307a2 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -3016,3 +3016,30 @@ if type(x) == 'number' then
print(<?x?>)
end
]]
+
+TEST 'boolean' [[
+---@class A
+---@field [integer] boolean
+local mt
+
+function mt:f()
+ ---@type integer
+ local index
+ local <?x?> = self[index]
+end
+]]
+
+TEST 'boolean' [[
+---@class A
+---@field [B] boolean
+
+---@class B
+
+---@type A
+local a
+
+---@type B
+local b
+
+local <?x?> = a[b]
+]]