summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/diagnostics/assign-type-mismatch.lua12
-rw-r--r--script/core/semantic-tokens.lua6
-rw-r--r--script/file-uri.lua3
-rw-r--r--script/provider/diagnostic.lua11
-rw-r--r--script/vm/def.lua2
-rw-r--r--script/vm/function.lua1
-rw-r--r--script/vm/global.lua12
-rw-r--r--script/vm/infer.lua2
-rw-r--r--script/vm/node.lua4
-rw-r--r--script/vm/ref.lua2
-rw-r--r--script/vm/sign.lua29
-rw-r--r--script/vm/type.lua33
-rw-r--r--script/vm/value.lua8
-rw-r--r--script/workspace/workspace.lua2
-rw-r--r--test/diagnostics/type-check.lua41
15 files changed, 116 insertions, 52 deletions
diff --git a/script/core/diagnostics/assign-type-mismatch.lua b/script/core/diagnostics/assign-type-mismatch.lua
index ae4b3512..3953cbc4 100644
--- a/script/core/diagnostics/assign-type-mismatch.lua
+++ b/script/core/diagnostics/assign-type-mismatch.lua
@@ -38,17 +38,9 @@ return function (uri, callback)
end
end
local valueNode = vm.compileNode(value)
- if source.type == 'setindex'
- and vm.isSubType(uri, valueNode, 'nil') then
+ if source.type == 'setindex' then
-- boolean[1] = nil
- local tnode = vm.compileNode(source.node)
- for n in tnode:eachObject() do
- if n.type == 'doc.type.array'
- or n.type == 'doc.type.table'
- or n.type == 'table' then
- return
- end
- end
+ valueNode = valueNode:copy():removeOptional()
end
local varNode = vm.compileNode(source)
if vm.canCastType(uri, varNode, valueNode) then
diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua
index 8be70198..5596245b 100644
--- a/script/core/semantic-tokens.lua
+++ b/script/core/semantic-tokens.lua
@@ -828,9 +828,13 @@ return function (uri, start, finish)
keyword = config.get(uri, 'Lua.semantic.keyword'),
}
+ local n = 0
guide.eachSourceBetween(state.ast, start, finish, function (source) ---@async
Care(source.type, source, options, results)
- await.delay()
+ n = n + 1
+ if n % 100 == 0 then
+ await.delay()
+ end
end)
for _, comm in ipairs(state.comms) do
diff --git a/script/file-uri.lua b/script/file-uri.lua
index ccd47156..877d2a1c 100644
--- a/script/file-uri.lua
+++ b/script/file-uri.lua
@@ -24,9 +24,6 @@ local m = {}
---@param path string
---@return uri uri
function m.encode(path)
- if not path then
- return nil
- end
local authority = ''
if platform.OS == 'Windows' then
path = path:gsub('\\', '/')
diff --git a/script/provider/diagnostic.lua b/script/provider/diagnostic.lua
index 90421c3e..713d4f48 100644
--- a/script/provider/diagnostic.lua
+++ b/script/provider/diagnostic.lua
@@ -275,12 +275,7 @@ function m.doDiagnostic(uri, isScopeDiag)
local diags = {}
local lastDiag = copyDiagsWithoutSyntax(m.cache[uri])
- local function pushResult(isComplete)
- -- Disable incremental diagnosis.
- -- The current diagnosis speed is good.
- if not isComplete then
- return
- end
+ local function pushResult()
tracy.ZoneBeginN 'mergeSyntaxAndDiags'
local _ <close> = tracy.ZoneEnd
local full = mergeDiags(syntax, lastDiag, diags)
@@ -310,7 +305,7 @@ function m.doDiagnostic(uri, isScopeDiag)
xpcall(core, log.error, uri, isScopeDiag, function (result)
diags[#diags+1] = buildDiagnostic(uri, result)
- if not isScopeDiag and time.time() - lastPushClock >= 200 then
+ if not isScopeDiag and time.time() - lastPushClock >= 1000 then
lastPushClock = time.time()
pushResult()
end
@@ -327,7 +322,7 @@ function m.doDiagnostic(uri, isScopeDiag)
end)
lastDiag = nil
- pushResult(true)
+ pushResult()
end
---@param uri uri
diff --git a/script/vm/def.lua b/script/vm/def.lua
index 03743826..a7af29b2 100644
--- a/script/vm/def.lua
+++ b/script/vm/def.lua
@@ -33,7 +33,7 @@ local searchFieldSwitch = util.switch()
end
end)
: case 'global'
- ---@param obj vm.object
+ ---@param obj vm.global
---@param key string
: call(function (suri, obj, key, pushResult)
if obj.cate == 'variable' then
diff --git a/script/vm/function.lua b/script/vm/function.lua
index e8fadb38..6c07acf6 100644
--- a/script/vm/function.lua
+++ b/script/vm/function.lua
@@ -62,6 +62,7 @@ function vm.countParamsOfNode(node)
for n in node:eachObject() do
if n.type == 'function'
or n.type == 'doc.type.function' then
+ ---@cast n parser.object
local fmin, fmax = vm.countParamsOfFunction(n)
if not min or fmin < min then
min = fmin
diff --git a/script/vm/global.lua b/script/vm/global.lua
index 50febd2c..b94e2768 100644
--- a/script/vm/global.lua
+++ b/script/vm/global.lua
@@ -162,6 +162,9 @@ local compilerGlobalSwitch = util.switch()
: call(function (source)
local uri = guide.getUri(source)
local name = guide.getKeyName(source)
+ if not name then
+ return
+ end
local global = vm.declareGlobal('variable', name, uri)
global:addSet(uri, source)
source._globalNode = global
@@ -170,6 +173,9 @@ local compilerGlobalSwitch = util.switch()
: call(function (source)
local uri = guide.getUri(source)
local name = guide.getKeyName(source)
+ if not name then
+ return
+ end
local global = vm.declareGlobal('variable', name, uri)
global:addGet(uri, source)
source._globalNode = global
@@ -272,6 +278,9 @@ local compilerGlobalSwitch = util.switch()
: call(function (source)
local uri = guide.getUri(source)
local name = guide.getKeyName(source)
+ if not name then
+ return
+ end
local class = vm.declareGlobal('type', name, uri)
class:addSet(uri, source)
source._globalNode = class
@@ -294,6 +303,9 @@ local compilerGlobalSwitch = util.switch()
: call(function (source)
local uri = guide.getUri(source)
local name = guide.getKeyName(source)
+ if not name then
+ return
+ end
local alias = vm.declareGlobal('type', name, uri)
alias:addSet(uri, source)
source._globalNode = alias
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index e789214a..d6c4da44 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -443,7 +443,7 @@ function mt:viewClass()
return table.concat(class, '|')
end
----@param source parser.object
+---@param source vm.node.object
---@param uri uri
---@return string?
function vm.viewObject(source, uri)
diff --git a/script/vm/node.lua b/script/vm/node.lua
index fce8c642..2128edb2 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -11,7 +11,7 @@ vm.nodeCache = {}
---@class vm.node
---@field [integer] vm.node.object
----@field [vm.object] true
+---@field [vm.node.object] true
local mt = {}
mt.__index = mt
mt.id = 0
@@ -38,6 +38,7 @@ function mt:merge(node)
end
end
else
+ ---@cast node -vm.node
if not self[node] then
self[node] = true
self[#self+1] = node
@@ -287,6 +288,7 @@ function mt:removeNode(node)
self:remove 'false'
end
else
+ ---@cast c -vm.global
self:removeObject(c)
end
end
diff --git a/script/vm/ref.lua b/script/vm/ref.lua
index c8b98acf..0135d11f 100644
--- a/script/vm/ref.lua
+++ b/script/vm/ref.lua
@@ -206,7 +206,7 @@ end
---@async
---@param source parser.object
---@param pushResult fun(src: parser.object)
----@param fileNotify fun(uri: uri): boolean
+---@param fileNotify? fun(uri: uri): boolean
function searchByParentNode(source, pushResult, defMap, fileNotify)
nodeSwitch(source.type, source, pushResult, defMap, fileNotify)
end
diff --git a/script/vm/sign.lua b/script/vm/sign.lua
index 0f5962fd..14d289eb 100644
--- a/script/vm/sign.lua
+++ b/script/vm/sign.lua
@@ -24,7 +24,7 @@ function mt:resolve(uri, args, removeGeneric)
end
local resolved = {}
- ---@param object parser.object
+ ---@param object vm.node.object
---@param node vm.node
local function resolve(object, node)
if object.type == 'doc.generic.name' then
@@ -33,6 +33,7 @@ function mt:resolve(uri, args, removeGeneric)
-- 'number' -> `T`
for n in node:eachObject() do
if n.type == 'string' then
+ ---@cast n parser.object
local type = vm.declareGlobal('type', n[1], guide.getUri(n))
resolved[key] = vm.createNode(type, resolved[key])
end
@@ -57,6 +58,7 @@ function mt:resolve(uri, args, removeGeneric)
end
if n.type == 'global' and n.cate == 'type' then
-- ---@field [integer]: number -> T[]
+ ---@cast n vm.global
vm.getClassFields(uri, n, vm.declareGlobal('type', 'integer'), false, function (field)
resolve(object.node, vm.compileNode(field.extends))
end)
@@ -67,23 +69,37 @@ function mt:resolve(uri, args, removeGeneric)
for _, ufield in ipairs(object.fields) do
local ufieldNode = vm.compileNode(ufield.name)
local uvalueNode = vm.compileNode(ufield.extends)
- if ufieldNode:get(1).type == 'doc.generic.name' and uvalueNode:get(1).type == 'doc.generic.name' then
+ local firstField = ufieldNode:get(1)
+ local firstValue = uvalueNode:get(1)
+ if not firstField or not firstValue then
+ goto CONTINUE
+ end
+ if firstField.type == 'doc.generic.name' and firstValue.type == 'doc.generic.name' then
-- { [number]: number} -> { [K]: V }
local tfieldNode = vm.getTableKey(uri, node, 'any')
local tvalueNode = vm.getTableValue(uri, node, 'any')
- resolve(ufieldNode:get(1), tfieldNode)
- resolve(uvalueNode:get(1), tvalueNode)
+ if tfieldNode then
+ resolve(firstField, tfieldNode)
+ end
+ if tvalueNode then
+ resolve(firstValue, tvalueNode)
+ end
else
if ufieldNode:get(1).type == 'doc.generic.name' then
-- { [number]: number}|number[] -> { [K]: number }
local tnode = vm.getTableKey(uri, node, uvalueNode)
- resolve(ufieldNode:get(1), tnode)
+ if tnode then
+ resolve(firstField, tnode)
+ end
elseif uvalueNode:get(1).type == 'doc.generic.name' then
-- { [number]: number}|number[] -> { [number]: V }
local tnode = vm.getTableValue(uri, node, ufieldNode)
- resolve(uvalueNode:get(1), tnode)
+ if tnode then
+ resolve(firstValue, tnode)
+ end
end
end
+ ::CONTINUE::
end
end
end
@@ -102,6 +118,7 @@ function mt:resolve(uri, args, removeGeneric)
if obj.type == 'doc.type.table'
or obj.type == 'doc.type.function'
or obj.type == 'doc.type.array' then
+ ---@cast obj parser.object
local hasGeneric
guide.eachSourceType(obj, 'doc.generic.name', function (src)
hasGeneric = true
diff --git a/script/vm/type.lua b/script/vm/type.lua
index d8adac6e..5211c3cb 100644
--- a/script/vm/type.lua
+++ b/script/vm/type.lua
@@ -2,7 +2,7 @@
local vm = require 'vm.vm'
local guide = require 'parser.guide'
----@param object vm.object
+---@param object vm.node.object
---@return string?
local function getNodeName(object)
if object.type == 'global' and object.cate == 'type' then
@@ -39,18 +39,19 @@ local function getNodeName(object)
end
---@param uri uri
----@param child vm.node|string|vm.object
----@param parent vm.node|string|vm.object
+---@param child vm.node|string|vm.node.object
+---@param parent vm.node|string|vm.node.object
---@param mark? table
---@return boolean
function vm.isSubType(uri, child, parent, mark)
mark = mark or {}
if type(child) == 'string' then
- child = vm.getGlobal('type', child)
- if not child then
+ local global = vm.getGlobal('type', child)
+ if not global then
return false
end
+ child = global
elseif child.type == 'vm.node' then
for n in child:eachObject() do
if getNodeName(n)
@@ -67,10 +68,11 @@ function vm.isSubType(uri, child, parent, mark)
end
if type(parent) == 'string' then
- parent = vm.getGlobal('type', parent)
- if not parent then
+ local global = vm.getGlobal('type', parent)
+ if not global then
return false
end
+ parent = global
elseif parent.type == 'vm.node' then
for n in parent:eachObject() do
if getNodeName(n)
@@ -89,15 +91,17 @@ function vm.isSubType(uri, child, parent, mark)
return false
end
- ---@cast child vm.object
- ---@cast parent vm.object
+ ---@cast child vm.node.object
+ ---@cast parent vm.node.object
local childName = getNodeName(child)
local parentName = getNodeName(parent)
if childName == 'any'
or parentName == 'any'
or childName == 'unknown'
- or parentName == 'unknown' then
+ or parentName == 'unknown'
+ or not childName
+ or not parentName then
return true
end
@@ -140,13 +144,12 @@ function vm.isSubType(uri, child, parent, mark)
end
end
end
- if set.type == 'doc.alias' and set.extends then
- if vm.isSubType(uri, vm.compileNode(set.extends), parent, mark) then
- return true
- end
+ if set.type == 'doc.alias' then
+ return true
end
end
end
+ mark[childName] = nil
end
return false
@@ -204,7 +207,7 @@ end
---@param uri uri
---@param tnode vm.node
----@param vnode vm.node
+---@param vnode vm.node|string|vm.object
---@return vm.node?
function vm.getTableKey(uri, tnode, vnode)
local result = vm.createNode()
diff --git a/script/vm/value.lua b/script/vm/value.lua
index e6e2045d..0ebf5d08 100644
--- a/script/vm/value.lua
+++ b/script/vm/value.lua
@@ -50,7 +50,7 @@ function vm.test(source)
end
end
----@param v vm.object
+---@param v vm.node.object
---@return string?
local function getUnique(v)
if v.type == 'boolean' then
@@ -79,11 +79,11 @@ local function getUnique(v)
---@cast v parser.object
return ('func:%s@%d'):format(guide.getUri(v), v.start)
end
- return false
+ return nil
end
----@param a vm.object?
----@param b vm.object?
+---@param a parser.object?
+---@param b parser.object?
---@return boolean|nil
function vm.equal(a, b)
if not a or not b then
diff --git a/script/workspace/workspace.lua b/script/workspace/workspace.lua
index 4c0011b3..3aebc5e0 100644
--- a/script/workspace/workspace.lua
+++ b/script/workspace/workspace.lua
@@ -48,7 +48,7 @@ end
function m.create(uri)
log.info('Workspace create: ', uri)
if uri == furi.encode '/'
- or uri == furi.encode(os.getenv 'HOME') then
+ or uri == furi.encode(os.getenv 'HOME' or '') then
client.showMessage('Error', lang.script('WORKSPACE_NOT_ALLOWED', furi.decode(uri)))
return
end
diff --git a/test/diagnostics/type-check.lua b/test/diagnostics/type-check.lua
index 147d6229..036456fe 100644
--- a/test/diagnostics/type-check.lua
+++ b/test/diagnostics/type-check.lua
@@ -396,5 +396,46 @@ a.x = XX
f(a.x)
]]
+TEST [[
+---@type string?
+local x
+
+local s = <!x!>:upper()
+]]
+
+TEST [[
+---@alias A string|boolean
+
+---@param x string|boolean
+local function f(x) end
+
+---@type A
+local x
+
+f(x)
+]]
+
+TEST [[
+---@alias A string|boolean
+
+---@param x A
+local function f(x) end
+
+---@type string|boolean
+local x
+
+f(x)
+]]
+
+TEST [[
+---@type boolean[]
+local t = {}
+
+---@type boolean?
+local x
+
+t[#t+1] = x
+]]
+
config.remove(nil, 'Lua.diagnostics.disable', 'unused-local')
config.remove(nil, 'Lua.diagnostics.disable', 'undefined-global')