diff options
-rw-r--r-- | script/core/diagnostics/assign-type-mismatch.lua | 12 | ||||
-rw-r--r-- | script/core/semantic-tokens.lua | 6 | ||||
-rw-r--r-- | script/file-uri.lua | 3 | ||||
-rw-r--r-- | script/provider/diagnostic.lua | 11 | ||||
-rw-r--r-- | script/vm/def.lua | 2 | ||||
-rw-r--r-- | script/vm/function.lua | 1 | ||||
-rw-r--r-- | script/vm/global.lua | 12 | ||||
-rw-r--r-- | script/vm/infer.lua | 2 | ||||
-rw-r--r-- | script/vm/node.lua | 4 | ||||
-rw-r--r-- | script/vm/ref.lua | 2 | ||||
-rw-r--r-- | script/vm/sign.lua | 29 | ||||
-rw-r--r-- | script/vm/type.lua | 33 | ||||
-rw-r--r-- | script/vm/value.lua | 8 | ||||
-rw-r--r-- | script/workspace/workspace.lua | 2 | ||||
-rw-r--r-- | test/diagnostics/type-check.lua | 41 |
15 files changed, 116 insertions, 52 deletions
diff --git a/script/core/diagnostics/assign-type-mismatch.lua b/script/core/diagnostics/assign-type-mismatch.lua index ae4b3512..3953cbc4 100644 --- a/script/core/diagnostics/assign-type-mismatch.lua +++ b/script/core/diagnostics/assign-type-mismatch.lua @@ -38,17 +38,9 @@ return function (uri, callback) end end local valueNode = vm.compileNode(value) - if source.type == 'setindex' - and vm.isSubType(uri, valueNode, 'nil') then + if source.type == 'setindex' then -- boolean[1] = nil - local tnode = vm.compileNode(source.node) - for n in tnode:eachObject() do - if n.type == 'doc.type.array' - or n.type == 'doc.type.table' - or n.type == 'table' then - return - end - end + valueNode = valueNode:copy():removeOptional() end local varNode = vm.compileNode(source) if vm.canCastType(uri, varNode, valueNode) then diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua index 8be70198..5596245b 100644 --- a/script/core/semantic-tokens.lua +++ b/script/core/semantic-tokens.lua @@ -828,9 +828,13 @@ return function (uri, start, finish) keyword = config.get(uri, 'Lua.semantic.keyword'), } + local n = 0 guide.eachSourceBetween(state.ast, start, finish, function (source) ---@async Care(source.type, source, options, results) - await.delay() + n = n + 1 + if n % 100 == 0 then + await.delay() + end end) for _, comm in ipairs(state.comms) do diff --git a/script/file-uri.lua b/script/file-uri.lua index ccd47156..877d2a1c 100644 --- a/script/file-uri.lua +++ b/script/file-uri.lua @@ -24,9 +24,6 @@ local m = {} ---@param path string ---@return uri uri function m.encode(path) - if not path then - return nil - end local authority = '' if platform.OS == 'Windows' then path = path:gsub('\\', '/') diff --git a/script/provider/diagnostic.lua b/script/provider/diagnostic.lua index 90421c3e..713d4f48 100644 --- a/script/provider/diagnostic.lua +++ b/script/provider/diagnostic.lua @@ -275,12 +275,7 @@ function m.doDiagnostic(uri, isScopeDiag) local diags = {} local lastDiag = copyDiagsWithoutSyntax(m.cache[uri]) - local function pushResult(isComplete) - -- Disable incremental diagnosis. - -- The current diagnosis speed is good. - if not isComplete then - return - end + local function pushResult() tracy.ZoneBeginN 'mergeSyntaxAndDiags' local _ <close> = tracy.ZoneEnd local full = mergeDiags(syntax, lastDiag, diags) @@ -310,7 +305,7 @@ function m.doDiagnostic(uri, isScopeDiag) xpcall(core, log.error, uri, isScopeDiag, function (result) diags[#diags+1] = buildDiagnostic(uri, result) - if not isScopeDiag and time.time() - lastPushClock >= 200 then + if not isScopeDiag and time.time() - lastPushClock >= 1000 then lastPushClock = time.time() pushResult() end @@ -327,7 +322,7 @@ function m.doDiagnostic(uri, isScopeDiag) end) lastDiag = nil - pushResult(true) + pushResult() end ---@param uri uri diff --git a/script/vm/def.lua b/script/vm/def.lua index 03743826..a7af29b2 100644 --- a/script/vm/def.lua +++ b/script/vm/def.lua @@ -33,7 +33,7 @@ local searchFieldSwitch = util.switch() end end) : case 'global' - ---@param obj vm.object + ---@param obj vm.global ---@param key string : call(function (suri, obj, key, pushResult) if obj.cate == 'variable' then diff --git a/script/vm/function.lua b/script/vm/function.lua index e8fadb38..6c07acf6 100644 --- a/script/vm/function.lua +++ b/script/vm/function.lua @@ -62,6 +62,7 @@ function vm.countParamsOfNode(node) for n in node:eachObject() do if n.type == 'function' or n.type == 'doc.type.function' then + ---@cast n parser.object local fmin, fmax = vm.countParamsOfFunction(n) if not min or fmin < min then min = fmin diff --git a/script/vm/global.lua b/script/vm/global.lua index 50febd2c..b94e2768 100644 --- a/script/vm/global.lua +++ b/script/vm/global.lua @@ -162,6 +162,9 @@ local compilerGlobalSwitch = util.switch() : call(function (source) local uri = guide.getUri(source) local name = guide.getKeyName(source) + if not name then + return + end local global = vm.declareGlobal('variable', name, uri) global:addSet(uri, source) source._globalNode = global @@ -170,6 +173,9 @@ local compilerGlobalSwitch = util.switch() : call(function (source) local uri = guide.getUri(source) local name = guide.getKeyName(source) + if not name then + return + end local global = vm.declareGlobal('variable', name, uri) global:addGet(uri, source) source._globalNode = global @@ -272,6 +278,9 @@ local compilerGlobalSwitch = util.switch() : call(function (source) local uri = guide.getUri(source) local name = guide.getKeyName(source) + if not name then + return + end local class = vm.declareGlobal('type', name, uri) class:addSet(uri, source) source._globalNode = class @@ -294,6 +303,9 @@ local compilerGlobalSwitch = util.switch() : call(function (source) local uri = guide.getUri(source) local name = guide.getKeyName(source) + if not name then + return + end local alias = vm.declareGlobal('type', name, uri) alias:addSet(uri, source) source._globalNode = alias diff --git a/script/vm/infer.lua b/script/vm/infer.lua index e789214a..d6c4da44 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -443,7 +443,7 @@ function mt:viewClass() return table.concat(class, '|') end ----@param source parser.object +---@param source vm.node.object ---@param uri uri ---@return string? function vm.viewObject(source, uri) diff --git a/script/vm/node.lua b/script/vm/node.lua index fce8c642..2128edb2 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -11,7 +11,7 @@ vm.nodeCache = {} ---@class vm.node ---@field [integer] vm.node.object ----@field [vm.object] true +---@field [vm.node.object] true local mt = {} mt.__index = mt mt.id = 0 @@ -38,6 +38,7 @@ function mt:merge(node) end end else + ---@cast node -vm.node if not self[node] then self[node] = true self[#self+1] = node @@ -287,6 +288,7 @@ function mt:removeNode(node) self:remove 'false' end else + ---@cast c -vm.global self:removeObject(c) end end diff --git a/script/vm/ref.lua b/script/vm/ref.lua index c8b98acf..0135d11f 100644 --- a/script/vm/ref.lua +++ b/script/vm/ref.lua @@ -206,7 +206,7 @@ end ---@async ---@param source parser.object ---@param pushResult fun(src: parser.object) ----@param fileNotify fun(uri: uri): boolean +---@param fileNotify? fun(uri: uri): boolean function searchByParentNode(source, pushResult, defMap, fileNotify) nodeSwitch(source.type, source, pushResult, defMap, fileNotify) end diff --git a/script/vm/sign.lua b/script/vm/sign.lua index 0f5962fd..14d289eb 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -24,7 +24,7 @@ function mt:resolve(uri, args, removeGeneric) end local resolved = {} - ---@param object parser.object + ---@param object vm.node.object ---@param node vm.node local function resolve(object, node) if object.type == 'doc.generic.name' then @@ -33,6 +33,7 @@ function mt:resolve(uri, args, removeGeneric) -- 'number' -> `T` for n in node:eachObject() do if n.type == 'string' then + ---@cast n parser.object local type = vm.declareGlobal('type', n[1], guide.getUri(n)) resolved[key] = vm.createNode(type, resolved[key]) end @@ -57,6 +58,7 @@ function mt:resolve(uri, args, removeGeneric) end if n.type == 'global' and n.cate == 'type' then -- ---@field [integer]: number -> T[] + ---@cast n vm.global vm.getClassFields(uri, n, vm.declareGlobal('type', 'integer'), false, function (field) resolve(object.node, vm.compileNode(field.extends)) end) @@ -67,23 +69,37 @@ function mt:resolve(uri, args, removeGeneric) for _, ufield in ipairs(object.fields) do local ufieldNode = vm.compileNode(ufield.name) local uvalueNode = vm.compileNode(ufield.extends) - if ufieldNode:get(1).type == 'doc.generic.name' and uvalueNode:get(1).type == 'doc.generic.name' then + local firstField = ufieldNode:get(1) + local firstValue = uvalueNode:get(1) + if not firstField or not firstValue then + goto CONTINUE + end + if firstField.type == 'doc.generic.name' and firstValue.type == 'doc.generic.name' then -- { [number]: number} -> { [K]: V } local tfieldNode = vm.getTableKey(uri, node, 'any') local tvalueNode = vm.getTableValue(uri, node, 'any') - resolve(ufieldNode:get(1), tfieldNode) - resolve(uvalueNode:get(1), tvalueNode) + if tfieldNode then + resolve(firstField, tfieldNode) + end + if tvalueNode then + resolve(firstValue, tvalueNode) + end else if ufieldNode:get(1).type == 'doc.generic.name' then -- { [number]: number}|number[] -> { [K]: number } local tnode = vm.getTableKey(uri, node, uvalueNode) - resolve(ufieldNode:get(1), tnode) + if tnode then + resolve(firstField, tnode) + end elseif uvalueNode:get(1).type == 'doc.generic.name' then -- { [number]: number}|number[] -> { [number]: V } local tnode = vm.getTableValue(uri, node, ufieldNode) - resolve(uvalueNode:get(1), tnode) + if tnode then + resolve(firstValue, tnode) + end end end + ::CONTINUE:: end end end @@ -102,6 +118,7 @@ function mt:resolve(uri, args, removeGeneric) if obj.type == 'doc.type.table' or obj.type == 'doc.type.function' or obj.type == 'doc.type.array' then + ---@cast obj parser.object local hasGeneric guide.eachSourceType(obj, 'doc.generic.name', function (src) hasGeneric = true diff --git a/script/vm/type.lua b/script/vm/type.lua index d8adac6e..5211c3cb 100644 --- a/script/vm/type.lua +++ b/script/vm/type.lua @@ -2,7 +2,7 @@ local vm = require 'vm.vm' local guide = require 'parser.guide' ----@param object vm.object +---@param object vm.node.object ---@return string? local function getNodeName(object) if object.type == 'global' and object.cate == 'type' then @@ -39,18 +39,19 @@ local function getNodeName(object) end ---@param uri uri ----@param child vm.node|string|vm.object ----@param parent vm.node|string|vm.object +---@param child vm.node|string|vm.node.object +---@param parent vm.node|string|vm.node.object ---@param mark? table ---@return boolean function vm.isSubType(uri, child, parent, mark) mark = mark or {} if type(child) == 'string' then - child = vm.getGlobal('type', child) - if not child then + local global = vm.getGlobal('type', child) + if not global then return false end + child = global elseif child.type == 'vm.node' then for n in child:eachObject() do if getNodeName(n) @@ -67,10 +68,11 @@ function vm.isSubType(uri, child, parent, mark) end if type(parent) == 'string' then - parent = vm.getGlobal('type', parent) - if not parent then + local global = vm.getGlobal('type', parent) + if not global then return false end + parent = global elseif parent.type == 'vm.node' then for n in parent:eachObject() do if getNodeName(n) @@ -89,15 +91,17 @@ function vm.isSubType(uri, child, parent, mark) return false end - ---@cast child vm.object - ---@cast parent vm.object + ---@cast child vm.node.object + ---@cast parent vm.node.object local childName = getNodeName(child) local parentName = getNodeName(parent) if childName == 'any' or parentName == 'any' or childName == 'unknown' - or parentName == 'unknown' then + or parentName == 'unknown' + or not childName + or not parentName then return true end @@ -140,13 +144,12 @@ function vm.isSubType(uri, child, parent, mark) end end end - if set.type == 'doc.alias' and set.extends then - if vm.isSubType(uri, vm.compileNode(set.extends), parent, mark) then - return true - end + if set.type == 'doc.alias' then + return true end end end + mark[childName] = nil end return false @@ -204,7 +207,7 @@ end ---@param uri uri ---@param tnode vm.node ----@param vnode vm.node +---@param vnode vm.node|string|vm.object ---@return vm.node? function vm.getTableKey(uri, tnode, vnode) local result = vm.createNode() diff --git a/script/vm/value.lua b/script/vm/value.lua index e6e2045d..0ebf5d08 100644 --- a/script/vm/value.lua +++ b/script/vm/value.lua @@ -50,7 +50,7 @@ function vm.test(source) end end ----@param v vm.object +---@param v vm.node.object ---@return string? local function getUnique(v) if v.type == 'boolean' then @@ -79,11 +79,11 @@ local function getUnique(v) ---@cast v parser.object return ('func:%s@%d'):format(guide.getUri(v), v.start) end - return false + return nil end ----@param a vm.object? ----@param b vm.object? +---@param a parser.object? +---@param b parser.object? ---@return boolean|nil function vm.equal(a, b) if not a or not b then diff --git a/script/workspace/workspace.lua b/script/workspace/workspace.lua index 4c0011b3..3aebc5e0 100644 --- a/script/workspace/workspace.lua +++ b/script/workspace/workspace.lua @@ -48,7 +48,7 @@ end function m.create(uri) log.info('Workspace create: ', uri) if uri == furi.encode '/' - or uri == furi.encode(os.getenv 'HOME') then + or uri == furi.encode(os.getenv 'HOME' or '') then client.showMessage('Error', lang.script('WORKSPACE_NOT_ALLOWED', furi.decode(uri))) return end diff --git a/test/diagnostics/type-check.lua b/test/diagnostics/type-check.lua index 147d6229..036456fe 100644 --- a/test/diagnostics/type-check.lua +++ b/test/diagnostics/type-check.lua @@ -396,5 +396,46 @@ a.x = XX f(a.x) ]] +TEST [[ +---@type string? +local x + +local s = <!x!>:upper() +]] + +TEST [[ +---@alias A string|boolean + +---@param x string|boolean +local function f(x) end + +---@type A +local x + +f(x) +]] + +TEST [[ +---@alias A string|boolean + +---@param x A +local function f(x) end + +---@type string|boolean +local x + +f(x) +]] + +TEST [[ +---@type boolean[] +local t = {} + +---@type boolean? +local x + +t[#t+1] = x +]] + config.remove(nil, 'Lua.diagnostics.disable', 'unused-local') config.remove(nil, 'Lua.diagnostics.disable', 'undefined-global') |