diff options
-rw-r--r-- | script/core/collector.lua | 7 | ||||
-rw-r--r-- | script/core/diagnostics/duplicate-doc-field.lua | 8 | ||||
-rw-r--r-- | script/version.lua | 2 | ||||
-rw-r--r-- | script/vm/compiler.lua | 23 | ||||
-rw-r--r-- | script/vm/node.lua | 11 | ||||
-rw-r--r-- | script/vm/type.lua | 5 | ||||
-rw-r--r-- | script/vm/value.lua | 2 | ||||
-rw-r--r-- | test/diagnostics/common.lua | 16 | ||||
-rw-r--r-- | test/type_inference/init.lua | 27 |
9 files changed, 93 insertions, 8 deletions
diff --git a/script/core/collector.lua b/script/core/collector.lua index a2e3ca08..368a04ec 100644 --- a/script/core/collector.lua +++ b/script/core/collector.lua @@ -71,6 +71,8 @@ local DUMMY_FUNCTION = function () end local function eachOfFolder(nameCollect, scp) local curi, value + ---@return any + ---@return uri local function getNext() curi, value = next(nameCollect, curi) if not curi then @@ -90,6 +92,8 @@ end local function eachOfLinked(nameCollect, scp) local curi, value + ---@return any + ---@return uri local function getNext() curi, value = next(nameCollect, curi) if not curi then @@ -120,6 +124,8 @@ end local function eachOfFallback(nameCollect, scp) local curi, value + ---@return any + ---@return uri local function getNext() curi, value = next(nameCollect, curi) if not curi then @@ -146,6 +152,7 @@ end --- 迭代某个名字的订阅 ---@param uri uri ---@param name string +---@return fun():any, uri function mt:each(uri, name) uri = uri or '<fallback>' local nameCollect = self.collect[name] diff --git a/script/core/diagnostics/duplicate-doc-field.lua b/script/core/diagnostics/duplicate-doc-field.lua index d4116b9b..78112beb 100644 --- a/script/core/diagnostics/duplicate-doc-field.lua +++ b/script/core/diagnostics/duplicate-doc-field.lua @@ -1,5 +1,6 @@ local files = require 'files' local lang = require 'language' +local vm = require 'vm.vm' local function getFieldEventName(doc) if not doc.extends then @@ -45,7 +46,12 @@ return function (uri, callback) mark = {} elseif doc.type == 'doc.field' then if mark then - local name = ('%q'):format(doc.field[1]) + local name + if doc.field.type == 'doc.type' then + name = ('[%s]'):format(vm.getInfer(doc.field):view(uri)) + else + name = ('%q'):format(doc.field[1]) + end local eventName = getFieldEventName(doc) if eventName then name = name .. '|' .. eventName diff --git a/script/version.lua b/script/version.lua index fa178564..f5cf5304 100644 --- a/script/version.lua +++ b/script/version.lua @@ -1,7 +1,7 @@ local fsu = require 'fs-utility' local function loadVersion() - local changelog = fsu.loadFile(ROOT / 'changelog.md') + local changelog = fsu.loadFile(ROOT / 'changelog.md'--[[@as fspath]]) if not changelog then return end diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 27ba6273..70a4ea92 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -253,7 +253,11 @@ local searchFieldSwitch = util.switch() end end) - +---@param suri uri +---@param object vm.global +---@param key string|vm.global +---@param ref boolean +---@param pushResult fun(field: vm.object, isMark?: boolean) function vm.getClassFields(suri, object, key, ref, pushResult) local mark = {} @@ -376,7 +380,7 @@ function vm.getClassFields(suri, object, key, ref, pushResult) for _, set in ipairs(sets) do pushResult(set) end - else + elseif type(key) == 'string' then local global = vm.getGlobal('variable', key) if global then for _, set in ipairs(global:getSets(suri)) do @@ -677,7 +681,7 @@ local function bindAs(source) end ---@param source vm.node ----@param key? any +---@param key? string|vm.global ---@param pushResult fun(source: parser.object) function vm.compileByParentNode(source, key, ref, pushResult) local parentNode = vm.compileNode(source) @@ -1380,12 +1384,25 @@ local compilerSwitch = util.switch() return end if type(key) == 'table' then + ---@cast key vm.node local uri = guide.getUri(source) local value = vm.getTableValue(uri, vm.compileNode(source.node), key) if value then vm.setNode(source, value):removeOptional() end + for k in key:eachObject() do + if k.type == 'global' and k.cate == 'type' then + ---@cast k vm.global + vm.compileByParentNode(source.node, k, false, function (src) + vm.setNode(source, vm.compileNode(src)) + if src.value then + vm.setNode(source, vm.compileNode(src.value)):removeOptional() + end + end) + end + end else + ---@cast key string vm.compileByParentNode(source.node, key, false, function (src) vm.setNode(source, vm.compileNode(src)) if src.value then diff --git a/script/vm/node.lua b/script/vm/node.lua index 65a203f8..24f87a45 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -9,6 +9,7 @@ vm.nodeCache = {} ---@class vm.node ---@field [integer] vm.object +---@field [vm.object] true local mt = {} mt.__index = mt mt.id = 0 @@ -234,8 +235,7 @@ function mt:narrow(name) end for index = #self, 1, -1 do local c = self[index] - if (c.type == 'global' and c.cate == 'type' and c.name == name) - or (c.type == name) + if (c.type == name) or (c.type == 'doc.type.integer' and (name == 'number' or name == 'integer')) or (c.type == 'doc.type.boolean' and name == 'boolean') or (c.type == 'doc.type.table' and name == 'table') @@ -243,6 +243,12 @@ function mt:narrow(name) or (c.type == 'doc.type.function' and name == 'function') then goto CONTINUE end + if c.type == 'global' and c.cate == 'type' then + if (c.name == name) + or (c.name == 'integer' and name == 'number') then + goto CONTINUE + end + end table.remove(self, index) self[c] = nil ::CONTINUE:: @@ -268,6 +274,7 @@ end function mt:removeNode(node) for _, c in ipairs(node) do if c.type == 'global' and c.cate == 'type' then + ---@cast c vm.global self:remove(c.name) elseif c.type == 'nil' then self:remove 'nil' diff --git a/script/vm/type.lua b/script/vm/type.lua index 05ada9ea..982f937d 100644 --- a/script/vm/type.lua +++ b/script/vm/type.lua @@ -84,6 +84,9 @@ function vm.isSubType(uri, child, parent, mark) return false end + ---@cast child vm.object + ---@cast parent vm.object + local childName = getNodeName(child) local parentName = getNodeName(parent) if childName == 'any' @@ -146,7 +149,7 @@ end ---@param uri uri ---@param tnode vm.node ----@param knode vm.node +---@param knode vm.node|string ---@return vm.node? function vm.getTableValue(uri, tnode, knode) local result = vm.createNode() diff --git a/script/vm/value.lua b/script/vm/value.lua index 13e27e59..e6e2045d 100644 --- a/script/vm/value.lua +++ b/script/vm/value.lua @@ -72,9 +72,11 @@ local function getUnique(v) return ('num:%s'):format(v[1]) end if v.type == 'table' then + ---@cast v parser.object return ('table:%s@%d'):format(guide.getUri(v), v.start) end if v.type == 'function' then + ---@cast v parser.object return ('func:%s@%d'):format(guide.getUri(v), v.start) end return false diff --git a/test/diagnostics/common.lua b/test/diagnostics/common.lua index 0708f63a..a837fc0f 100644 --- a/test/diagnostics/common.lua +++ b/test/diagnostics/common.lua @@ -1659,3 +1659,19 @@ k(f()) TEST [[ ---@cast <!x!> integer ]] + +TEST [[ +---@class A + +---@class B +---@field [integer] A +---@field [A] true +]] + +TEST [[ +---@class A + +---@class B +---@field [A] A +---@field [<!A!>] true +]] diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index bca3a546..434307a2 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -3016,3 +3016,30 @@ if type(x) == 'number' then print(<?x?>) end ]] + +TEST 'boolean' [[ +---@class A +---@field [integer] boolean +local mt + +function mt:f() + ---@type integer + local index + local <?x?> = self[index] +end +]] + +TEST 'boolean' [[ +---@class A +---@field [B] boolean + +---@class B + +---@type A +local a + +---@type B +local b + +local <?x?> = a[b] +]] |