summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/diagnostics/assign-type-mismatch.lua2
-rw-r--r--script/core/diagnostics/cast-local-type.lua2
-rw-r--r--script/core/diagnostics/cast-type-mismatch.lua2
-rw-r--r--script/core/signature.lua2
-rw-r--r--script/parser/compile.lua27
-rw-r--r--script/parser/guide.lua3
-rw-r--r--script/vm/compiler.lua147
-rw-r--r--script/vm/infer.lua10
-rw-r--r--script/vm/init.lua2
-rw-r--r--script/vm/node.lua18
-rw-r--r--script/vm/runner.lua565
-rw-r--r--script/vm/tracer.lua541
-rw-r--r--test/tclient/tests/recursive-runner.lua3
-rw-r--r--test/type_inference/init.lua158
14 files changed, 793 insertions, 689 deletions
diff --git a/script/core/diagnostics/assign-type-mismatch.lua b/script/core/diagnostics/assign-type-mismatch.lua
index 566f4a27..8472e87c 100644
--- a/script/core/diagnostics/assign-type-mismatch.lua
+++ b/script/core/diagnostics/assign-type-mismatch.lua
@@ -61,7 +61,7 @@ return function (uri, callback)
await.delay()
if source.type == 'setlocal' then
local locNode = vm.compileNode(source.node)
- if not locNode:getData 'hasDefined' then
+ if not locNode.hasDefined then
return
end
end
diff --git a/script/core/diagnostics/cast-local-type.lua b/script/core/diagnostics/cast-local-type.lua
index 1b1c8432..1998b915 100644
--- a/script/core/diagnostics/cast-local-type.lua
+++ b/script/core/diagnostics/cast-local-type.lua
@@ -18,7 +18,7 @@ return function (uri, callback)
end
await.delay()
local locNode = vm.compileNode(loc)
- if not locNode:getData 'hasDefined' then
+ if not locNode.hasDefined then
return
end
for _, ref in ipairs(loc.ref) do
diff --git a/script/core/diagnostics/cast-type-mismatch.lua b/script/core/diagnostics/cast-type-mismatch.lua
index c0483459..b2d2bdf3 100644
--- a/script/core/diagnostics/cast-type-mismatch.lua
+++ b/script/core/diagnostics/cast-type-mismatch.lua
@@ -22,7 +22,7 @@ return function (uri, callback)
local loc = defs[1]
if loc then
local defNode = vm.compileNode(loc)
- if defNode:getData 'hasDefined' then
+ if defNode.hasDefined then
for _, cast in ipairs(doc.casts) do
if not cast.mode and cast.extends then
local refNode = vm.compileNode(cast.extends)
diff --git a/script/core/signature.lua b/script/core/signature.lua
index 3465fda2..63b0cd0d 100644
--- a/script/core/signature.lua
+++ b/script/core/signature.lua
@@ -134,7 +134,7 @@ local function makeSignatures(text, call, pos)
local signs = {}
local node = vm.compileNode(func)
---@type vm.node
- node = node:getData 'originNode' or node
+ node = node.originNode or node
local mark = {}
for src in node:eachObject() do
if (src.type == 'function' and not vm.isVarargFunctionWithOverloads(src))
diff --git a/script/parser/compile.lua b/script/parser/compile.lua
index b8040382..17b9b051 100644
--- a/script/parser/compile.lua
+++ b/script/parser/compile.lua
@@ -2232,6 +2232,7 @@ local function parseFunction(isLocal, isAction)
type = 'function',
start = funcLeft,
finish = funcRight,
+ bstart = funcRight,
keyword = {
[1] = funcLeft,
[2] = funcRight,
@@ -2262,6 +2263,7 @@ local function parseFunction(isLocal, isAction)
end
func.name = simple
func.finish = simple.finish
+ func.bstart = simple.finish
if not isAction then
simple.parent = func
pushError {
@@ -2302,6 +2304,7 @@ local function parseFunction(isLocal, isAction)
if Tokens[Index + 1] == ')' then
local parenRight = getPosition(Tokens[Index], 'right')
func.finish = parenRight
+ func.bstart = parenRight
if params then
params.finish = parenRight
end
@@ -2309,6 +2312,7 @@ local function parseFunction(isLocal, isAction)
skipSpace(true)
else
func.finish = lastRightPosition()
+ func.bstart = func.finish
if params then
params.finish = func.finish
end
@@ -2963,6 +2967,7 @@ local function parseDo()
type = 'do',
start = doLeft,
finish = doRight,
+ bstart = doRight,
keyword = {
[1] = doLeft,
[2] = doRight,
@@ -3145,6 +3150,7 @@ local function parseIfBlock(parent)
parent = parent,
start = ifLeft,
finish = ifRight,
+ bstart = ifRight,
keyword = {
[1] = ifLeft,
[2] = ifRight,
@@ -3155,7 +3161,8 @@ local function parseIfBlock(parent)
if filter then
ifblock.filter = filter
ifblock.finish = filter.finish
- filter.parent = ifblock
+ ifblock.bstart = ifblock.finish
+ filter.parent = ifblock
else
missExp()
end
@@ -3164,6 +3171,7 @@ local function parseIfBlock(parent)
if thenToken == 'then'
or thenToken == 'do' then
ifblock.finish = getPosition(Tokens[Index] + #thenToken - 1, 'right')
+ ifblock.bstart = ifblock.finish
ifblock.keyword[3] = getPosition(Tokens[Index], 'left')
ifblock.keyword[4] = ifblock.finish
if thenToken == 'do' then
@@ -3203,6 +3211,7 @@ local function parseElseIfBlock(parent)
parent = parent,
start = ifLeft,
finish = ifRight,
+ bstart = ifRight,
keyword = {
[1] = ifLeft,
[2] = ifRight,
@@ -3214,6 +3223,7 @@ local function parseElseIfBlock(parent)
if filter then
elseifblock.filter = filter
elseifblock.finish = filter.finish
+ elseifblock.bstart = elseifblock.finish
filter.parent = elseifblock
else
missExp()
@@ -3223,6 +3233,7 @@ local function parseElseIfBlock(parent)
if thenToken == 'then'
or thenToken == 'do' then
elseifblock.finish = getPosition(Tokens[Index] + #thenToken - 1, 'right')
+ elseifblock.bstart = elseifblock.finish
elseifblock.keyword[3] = getPosition(Tokens[Index], 'left')
elseifblock.keyword[4] = elseifblock.finish
if thenToken == 'do' then
@@ -3262,6 +3273,7 @@ local function parseElseBlock(parent)
parent = parent,
start = ifLeft,
finish = ifRight,
+ bstart = ifRight,
keyword = {
[1] = ifLeft,
[2] = ifRight,
@@ -3337,6 +3349,7 @@ local function parseFor()
finish = getPosition(Tokens[Index] + 2, 'right'),
keyword = {},
}
+ action.bstart = action.finish
action.keyword[1] = action.start
action.keyword[2] = action.finish
Index = Index + 2
@@ -3366,6 +3379,7 @@ local function parseFor()
local loc = createLocal(name)
loc.parent = action
action.finish = name.finish
+ action.bstart = action.finish
action.loc = loc
end
if expList then
@@ -3375,12 +3389,14 @@ local function parseFor()
value.parent = expList
action.init = value
action.finish = expList[#expList].finish
+ action.bstart = action.finish
end
local max = expList[2]
if max then
max.parent = expList
action.max = max
action.finish = max.finish
+ action.bstart = action.finish
else
pushError {
type = 'MISS_LOOP_MAX',
@@ -3393,6 +3409,7 @@ local function parseFor()
step.parent = expList
action.step = step
action.finish = step.finish
+ action.bstart = action.finish
end
else
pushError {
@@ -3414,7 +3431,8 @@ local function parseFor()
local exps = parseExpList()
- action.finish = inRight
+ action.finish = inRight
+ action.bstart = action.finish
action.keyword[3] = inLeft
action.keyword[4] = inRight
@@ -3435,6 +3453,7 @@ local function parseFor()
local lastExp = exps[#exps]
if lastExp then
action.finish = lastExp.finish
+ action.bstart = action.finish
end
action.exps = exps
@@ -3468,6 +3487,7 @@ local function parseFor()
local left = getPosition(Tokens[Index], 'left')
local right = getPosition(Tokens[Index] + #doToken - 1, 'right')
action.finish = left
+ action.bstart = action.finish
action.keyword[#action.keyword+1] = left
action.keyword[#action.keyword+1] = right
if doToken == 'then' then
@@ -3518,6 +3538,7 @@ local function parseWhile()
finish = getPosition(Tokens[Index] + 4, 'right'),
keyword = {},
}
+ action.bstart = action.finish
action.keyword[1] = action.start
action.keyword[2] = action.finish
Index = Index + 2
@@ -3542,6 +3563,7 @@ local function parseWhile()
local left = getPosition(Tokens[Index], 'left')
local right = getPosition(Tokens[Index] + #doToken - 1, 'right')
action.finish = left
+ action.bstart = action.finish
action.keyword[#action.keyword+1] = left
action.keyword[#action.keyword+1] = right
if doToken == 'then' then
@@ -3594,6 +3616,7 @@ local function parseRepeat()
finish = getPosition(Tokens[Index] + 5, 'right'),
keyword = {},
}
+ action.bstart = action.finish
action.keyword[1] = action.start
action.keyword[2] = action.finish
Index = Index + 2
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index 147e6237..f27a2af7 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -21,6 +21,7 @@ local type = type
---@field finish integer
---@field range integer
---@field effect integer
+---@field bstart integer
---@field attrs string[]
---@field specials parser.object[]
---@field labels parser.object[]
@@ -139,7 +140,7 @@ local childMap = {
['getfield'] = {'node', 'field'},
['list'] = {'#'},
['binary'] = {1, 2},
- ['unary'] = {1},
+ ['unary'] = { 1 },
['doc'] = {'#'},
['doc.class'] = {'class', '#extends', '#signs', 'comment'},
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 168ad536..446c357e 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -538,7 +538,7 @@ local function matchCall(source)
if needRemove then
local newNode = myNode:copy()
newNode:removeNode(needRemove)
- newNode:setData('originNode', myNode)
+ newNode.originNode = myNode
vm.setNode(source, newNode, true)
end
end
@@ -836,10 +836,13 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
for i = fixIndex + 1, myIndex - 1 do
args[#args+1] = call.args[i]
end
- fn = generic:resolve(guide.getUri(call), args)
+ local resolvedNode = generic:resolve(guide.getUri(call), args)
+ vm.setNode(arg, resolvedNode)
+ goto CONTINUE
end
end
vm.setNode(arg, fn)
+ ::CONTINUE::
end
end
end
@@ -907,9 +910,10 @@ end
---@param source parser.object
---@param target parser.object
+---@return boolean
local function compileForVars(source, target)
if not source.exps then
- return
+ return false
end
-- for k, v in pairs(t) do
--> for k, v in iterator, status, initValue do
@@ -940,9 +944,11 @@ local function compileForVars(source, target)
local node = getReturn(source._iterator, i, source._iterArgs)
node:removeOptional()
vm.setNode(loc, node)
+ return true
end
end
end
+ return false
end
---@param source parser.object
@@ -972,17 +978,6 @@ local function compileLocal(source)
vm.setNode(source, vm.compileNode(source.value))
end
end
- if not hasMarkValue and not hasMarkValue then
- if source.ref then
- for _, ref in ipairs(source.ref) do
- if ref.type == 'setlocal'
- and ref.value
- and ref.value.type == 'function' then
- vm.setNode(source, vm.compileNode(ref.value))
- end
- end
- end
- end
-- function x.y(self, ...) --> function x:y(...)
if source[1] == 'self'
and not hasMarkDoc
@@ -1021,6 +1016,7 @@ local function compileLocal(source)
-- for x in ... do
if source.parent.type == 'in' then
compileForVars(source.parent, source)
+ hasMarkDoc = true
end
-- for x = ... do
@@ -1030,10 +1026,33 @@ local function compileLocal(source)
return
end
vm.setNode(source, vm.declareGlobal('type', 'integer'))
+ hasMarkDoc = true
end
end
- myNode:setData('hasDefined', hasMarkDoc or hasMarkParam or hasMarkValue)
+ if not hasMarkDoc
+ and not hasMarkValue
+ and source.ref then
+ local firstSet
+ local myFunction = guide.getParentFunction(source)
+ for _, ref in ipairs(source.ref) do
+ if ref.type == 'setlocal' then
+ firstSet = ref
+ break
+ end
+ if ref.type == 'getlocal' then
+ if guide.getParentFunction(ref) == myFunction then
+ break
+ end
+ end
+ end
+ if firstSet
+ and guide.getBlock(firstSet) == guide.getBlock(source) then
+ vm.setNode(source, vm.compileNode(firstSet))
+ end
+ end
+
+ myNode.hasDefined = hasMarkDoc or hasMarkParam or hasMarkValue
end
---@param source parser.object
@@ -1163,75 +1182,27 @@ local compilerSwitch = util.switch()
---@async
---@param source parser.object
: call(function (source)
- vm.launchRunner(source, function ()
- local myNode = vm.getNode(source)
- ---@cast myNode -?
- myNode:setData('resolving', true)
-
- if source.ref then
- for _, ref in ipairs(source.ref) do
- if ref.type == 'getlocal'
- or ref.type == 'setlocal' then
- vm.setNode(ref, myNode, true)
- end
- end
- end
- compileLocal(source)
-
- myNode.resolved = true
- end, function ()
- local myNode = vm.getNode(source)
- ---@cast myNode -?
- myNode:setData('resolving', nil)
- local hasMark = vm.getNode(source):getData 'hasDefined'
- if source.ref and not hasMark then
- local parentFunc = guide.getParentFunction(source)
- for _, ref in ipairs(source.ref) do
- if ref.type == 'setlocal'
- and guide.getParentFunction(ref) == parentFunc then
- local refNode = vm.getNode(ref)
- if refNode then
- vm.setNode(source, refNode)
- end
- end
- end
- end
- end, function (src, node)
- if src.type == 'setlocal' then
- if src.bindDocs then
- for _, doc in ipairs(src.bindDocs) do
- if doc.type == 'doc.type' then
- vm.setNode(src, vm.compileNode(doc), true)
- return vm.getNode(src)
- end
- end
- end
- if src.value then
- if src.value.type == 'table' then
- vm.setNode(src, vm.createNode(src.value), true)
- vm.setNode(src, node:copy():asTable())
- else
- vm.setNode(src, vm.compileNode(src.value), true)
- end
- else
- vm.setNode(src, node, true)
- end
- return vm.getNode(src)
- elseif src.type == 'getlocal' then
- if bindAs(src) then
- return
- end
- vm.setNode(src, node, true)
- node.resolved = true
- matchCall(src)
- end
- end)
-
- vm.waitResolveRunner(source)
+ compileLocal(source)
end)
: case 'setlocal'
: call(function (source)
- vm.compileNode(source.node)
+ if bindDocs(source) then
+ return
+ end
+ local locNode = vm.compileNode(source.node)
+ if not source.value then
+ vm.setNode(source, locNode)
+ return
+ end
+ local valueNode = vm.compileNode(source.value)
+ vm.setNode(source, valueNode)
+ if locNode.hasDefined
+ and guide.isLiteral(source.value) then
+ vm.setNode(source, locNode)
+ vm.getNode(source):narrow(guide.getUri(source), source.value.type)
+ else
+ vm.setNode(source, valueNode)
+ end
end)
: case 'getlocal'
---@async
@@ -1239,8 +1210,11 @@ local compilerSwitch = util.switch()
if bindAs(source) then
return
end
- vm.compileNode(source.node)
- vm.waitResolveRunner(source)
+ local node = vm.traceNode(source)
+ if not node then
+ return
+ end
+ vm.setNode(source, node, true)
end)
: case 'setfield'
: case 'setmethod'
@@ -1921,13 +1895,6 @@ function vm.compileNode(source)
end
end
- if source.type == 'getlocal' then
- ---@cast source parser.object
- vm.storeWaitingRunner(source)
- ---@diagnostic disable-next-line: await-in-sync
- vm.waitResolveRunner(source)
- end
-
local cache = vm.getNode(source)
if cache ~= nil then
return cache
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index b9dfb29a..99cf622e 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -432,10 +432,14 @@ function mt:view(uri, default)
end
if self.node:isOptional() then
- if max > 1 then
- view = '(' .. view .. ')?'
+ if #array == 0 then
+ view = 'nil'
else
- view = view .. '?'
+ if max > 1 then
+ view = '(' .. view .. ')?'
+ else
+ view = view .. '?'
+ end
end
end
diff --git a/script/vm/init.lua b/script/vm/init.lua
index 7b69a7eb..9c8ebe55 100644
--- a/script/vm/init.lua
+++ b/script/vm/init.lua
@@ -11,7 +11,7 @@ require 'vm.field'
require 'vm.doc'
require 'vm.type'
require 'vm.library'
-require 'vm.runner'
+require 'vm.tracer'
require 'vm.infer'
require 'vm.generic'
require 'vm.sign'
diff --git a/script/vm/node.lua b/script/vm/node.lua
index 2e408128..65d752df 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -20,7 +20,8 @@ mt.id = 0
mt.type = 'vm.node'
mt.optional = nil
mt.data = nil
-mt.resolved = nil
+mt.hasDefined = nil
+mt.originNode = nil
---@param node vm.node | vm.node.object
---@return vm.node
@@ -70,21 +71,6 @@ function mt:get(n)
return self[n]
end
-function mt:setData(k, v)
- if not self.data then
- self.data = {}
- end
- self.data[k] = v
-end
-
----@return any
-function mt:getData(k)
- if not self.data then
- return nil
- end
- return self.data[k]
-end
-
function mt:addOptional()
self.optional = true
end
diff --git a/script/vm/runner.lua b/script/vm/runner.lua
deleted file mode 100644
index 8e264521..00000000
--- a/script/vm/runner.lua
+++ /dev/null
@@ -1,565 +0,0 @@
----@class vm
-local vm = require 'vm.vm'
-local guide = require 'parser.guide'
-local linked = require 'linked-table'
-
----@alias vm.runner.callback fun(src: parser.object, node?: vm.node)
-
----@class vm.runner
----@field _loc parser.object
----@field _casts parser.object[]
----@field _callback vm.runner.callback
----@field _mark table
----@field _has table<parser.object, true>
----@field _main parser.object
----@field _uri uri
-local mt = {}
-mt.__index = mt
-mt._index = 1
-
----@return parser.object[]
-function mt:_getCasts()
- local root = guide.getRoot(self._loc)
- if not root._casts then
- root._casts = {}
- local docs = root.docs
- for _, doc in ipairs(docs) do
- if doc.type == 'doc.cast' and doc.loc then
- root._casts[#root._casts+1] = doc
- end
- end
- end
- return root._casts
-end
-
----@param obj parser.object
-function mt:_markHas(obj)
- while true do
- if self._has[obj] then
- return
- end
- self._has[obj] = true
- if obj == self._main then
- return
- end
- obj = obj.parent
- end
-end
-
-function mt:collect()
- local startPos = self._loc.start
- local finishPos = 0
-
- for _, ref in ipairs(self._loc.ref) do
- if ref.type == 'getlocal'
- or ref.type == 'setlocal' then
- self:_markHas(ref)
- if ref.finish > finishPos then
- finishPos = ref.finish
- end
- end
- end
-
- local casts = self:_getCasts()
- for _, cast in ipairs(casts) do
- if cast.loc[1] == self._loc[1]
- and cast.start > startPos
- and cast.finish < finishPos
- and guide.getLocal(self._loc, self._loc[1], cast.start) == self._loc then
- self._casts[#self._casts+1] = cast
- end
- end
-end
-
----@param pos integer
----@param topNode vm.node
----@return vm.node
-function mt:_fastWardCasts(pos, topNode)
- for i = self._index, #self._casts do
- local action = self._casts[i]
- if action.start > pos then
- self._index = i
- return topNode
- end
- topNode = topNode:copy()
- for _, cast in ipairs(action.casts) do
- if cast.mode == '+' then
- if cast.optional then
- topNode:addOptional()
- end
- if cast.extends then
- topNode:merge(vm.compileNode(cast.extends))
- end
- elseif cast.mode == '-' then
- if cast.optional then
- topNode:removeOptional()
- end
- if cast.extends then
- topNode:removeNode(vm.compileNode(cast.extends))
- end
- else
- if cast.extends then
- topNode:clear()
- topNode:merge(vm.compileNode(cast.extends))
- end
- end
- end
- end
- self._index = self._index + 1
- return topNode
-end
-
----@param action parser.object
----@param topNode vm.node
----@param outNode? vm.node
----@return vm.node topNode
----@return vm.node outNode
-function mt:_lookIntoChild(action, topNode, outNode)
- if not self._has[action]
- or self._mark[action] then
- return topNode, topNode or outNode
- end
- self._mark[action] = true
- topNode = self:_fastWardCasts(action.start, topNode)
- if action.type == 'getlocal' then
- if action.node == self._loc then
- self._callback(action, topNode)
- if outNode then
- topNode = topNode:copy():setTruthy()
- outNode = outNode:copy():setFalsy()
- end
- end
- elseif action.type == 'function' then
- self:lookIntoBlock(action, topNode:copy())
- elseif action.type == 'unary' then
- if not action[1] then
- goto RETURN
- end
- if action.op.type == 'not' then
- outNode = outNode or topNode:copy()
- outNode, topNode = self:_lookIntoChild(action[1], topNode, outNode)
- outNode = outNode:copy()
- end
- elseif action.type == 'binary' then
- if not action[1] or not action[2] then
- goto RETURN
- end
- if action.op.type == 'and' then
- topNode = self:_lookIntoChild(action[1], topNode, topNode:copy())
- topNode = self:_lookIntoChild(action[2], topNode, topNode:copy())
- elseif action.op.type == 'or' then
- outNode = outNode or topNode:copy()
- local topNode1, outNode1 = self:_lookIntoChild(action[1], topNode, outNode)
- local topNode2, outNode2 = self:_lookIntoChild(action[2], outNode1, outNode1:copy())
- topNode = vm.createNode(topNode1, topNode2)
- outNode = outNode2:copy()
- elseif action.op.type == '=='
- or action.op.type == '~=' then
- local handler, checker
- for i = 1, 2 do
- if guide.isLiteral(action[i]) then
- checker = action[i]
- handler = action[3-i] -- Copilot tells me use `3-i` instead of `i%2+1`
- end
- end
- if not handler then
- goto RETURN
- end
- if handler.type == 'getlocal'
- and handler.node == self._loc then
- -- if x == y then
- topNode = self:_lookIntoChild(handler, topNode, outNode)
- local checkerNode = vm.compileNode(checker)
- local checkerName = vm.getNodeName(checker)
- if checkerName then
- topNode = topNode:copy()
- if action.op.type == '==' then
- topNode:narrow(self._uri, checkerName)
- if outNode then
- outNode:removeNode(checkerNode)
- end
- else
- topNode:removeNode(checkerNode)
- if outNode then
- outNode:narrow(self._uri, checkerName)
- end
- end
- end
- elseif handler.type == 'call'
- and checker.type == 'string'
- and handler.node.special == 'type'
- and handler.args
- and handler.args[1]
- and handler.args[1].type == 'getlocal'
- and handler.args[1].node == self._loc then
- -- if type(x) == 'string' then
- self:_lookIntoChild(handler, topNode:copy())
- if action.op.type == '==' then
- topNode:narrow(self._uri, checker[1])
- if outNode then
- outNode:remove(checker[1])
- end
- else
- topNode:remove(checker[1])
- if outNode then
- outNode:narrow(self._uri, checker[1])
- end
- end
- elseif handler.type == 'getlocal'
- and checker.type == 'string' then
- local nodeValue = vm.getObjectValue(handler.node)
- if nodeValue
- and nodeValue.type == 'select'
- and nodeValue.sindex == 1 then
- local call = nodeValue.vararg
- if call
- and call.type == 'call'
- and call.node.special == 'type'
- and call.args
- and call.args[1]
- and call.args[1].type == 'getlocal'
- and call.args[1].node == self._loc then
- -- `local tp = type(x);if tp == 'string' then`
- if action.op.type == '==' then
- topNode:narrow(self._uri, checker[1])
- if outNode then
- outNode:remove(checker[1])
- end
- else
- topNode:remove(checker[1])
- if outNode then
- outNode:narrow(self._uri, checker[1])
- end
- end
- end
- end
- end
- end
- elseif action.type == 'loop'
- or action.type == 'in'
- or action.type == 'repeat'
- or action.type == 'for' then
- topNode = self:lookIntoBlock(action, topNode:copy())
- elseif action.type == 'while' then
- local blockNode, mainNode
- if action.filter then
- blockNode, mainNode = self:_lookIntoChild(action.filter, topNode:copy(), topNode:copy())
- else
- blockNode = topNode:copy()
- mainNode = topNode:copy()
- end
- blockNode = self:lookIntoBlock(action, blockNode:copy())
- topNode = mainNode:merge(blockNode)
- if action.filter then
- -- look into filter again
- guide.eachSource(action.filter, function (src)
- self._mark[src] = nil
- end)
- blockNode, topNode = self:_lookIntoChild(action.filter, topNode:copy(), topNode:copy())
- end
- elseif action.type == 'if' then
- local hasElse
- local mainNode = topNode:copy()
- local blockNodes = {}
- for _, subBlock in ipairs(action) do
- local blockNode = mainNode:copy()
- if subBlock.filter then
- blockNode, mainNode = self:_lookIntoChild(subBlock.filter, blockNode, mainNode)
- else
- hasElse = true
- mainNode:clear()
- end
- blockNode = self:lookIntoBlock(subBlock, blockNode:copy())
- local neverReturn = subBlock.hasReturn
- or subBlock.hasGoTo
- or subBlock.hasBreak
- or subBlock.hasError
- if not neverReturn then
- blockNodes[#blockNodes+1] = blockNode
- end
- end
- if not hasElse and not topNode:hasKnownType() then
- mainNode:merge(vm.declareGlobal('type', 'unknown'))
- end
- for _, blockNode in ipairs(blockNodes) do
- mainNode:merge(blockNode)
- end
- topNode = mainNode
- elseif action.type == 'call' then
- if action.node.special == 'assert' and action.args and action.args[1] then
- topNode = self:_lookIntoChild(action.args[1], topNode, topNode:copy())
- end
- elseif action.type == 'paren' then
- topNode, outNode = self:_lookIntoChild(action.exp, topNode, outNode)
- elseif action.type == 'setlocal' then
- if action.node == self._loc then
- if action.value then
- self:_lookIntoChild(action.value, topNode)
- end
- topNode = self._callback(action, topNode)
- end
- elseif action.type == 'local' then
- if action.value
- and action.ref
- and action.value.type == 'select' then
- local index = action.value.sindex
- local call = action.value.vararg
- if index == 1
- and call.type == 'call'
- and call.node
- and call.node.special == 'type'
- and call.args then
- local getLoc = call.args[1]
- if getLoc
- and getLoc.type == 'getlocal'
- and getLoc.node == self._loc then
- for _, ref in ipairs(action.ref) do
- self:_markHas(ref)
- end
- end
- end
- end
- end
- ::RETURN::
- guide.eachChild(action, function (src)
- if self._has[src] then
- self:_lookIntoChild(src, topNode)
- end
- end)
- return topNode, outNode or topNode
-end
-
----@param block parser.object
----@param topNode vm.node
----@return vm.node topNode
-function mt:lookIntoBlock(block, topNode)
- if not self._has[block] then
- return topNode
- end
- for _, action in ipairs(block) do
- if self._has[action] then
- topNode = self:_lookIntoChild(action, topNode)
- end
- end
- topNode = self:_fastWardCasts(block.finish, topNode)
- return topNode
-end
-
----@alias runner.info { target?: parser.object, loc: parser.object }
-
----@type thread?
-local masterRunner = nil
----@type table<thread, runner.info>
-local runnerInfo = setmetatable({}, {
- __mode = 'k',
- __index = function (self, k)
- self[k] = {}
- return self[k]
- end
-})
----@type linked-table?
-local runnerList = nil
-
----@async
----@param info runner.info
-local function waitResolve(info)
- while true do
- if not info.target then
- break
- end
- if info.target.node == info.loc then
- break
- end
- local node = vm.getNode(info.target)
- if node and node.resolved then
- break
- end
- coroutine.yield()
- end
- info.target = nil
-end
-
-local function resolveDeadLock()
- if not runnerList then
- return
- end
-
- ---@type runner.info[]
- local infos = {}
- for runner in runnerList:pairs() do
- local info = runnerInfo[runner]
- infos[#infos+1] = info
- end
-
- table.sort(infos, function (a, b)
- local uriA = guide.getUri(a.loc)
- local uriB = guide.getUri(b.loc)
- if uriA ~= uriB then
- return uriA < uriB
- end
- return a.loc.start < b.loc.start
- end)
-
- local firstTarget = infos[1].target
- ---@cast firstTarget -?
- local firstNode = vm.setNode(firstTarget, vm.getNode(firstTarget):copy(), true)
- firstNode.resolved = true
- firstNode:setData('resolvedByDeadLock', true)
-end
-
----@async
----@param loc parser.object
----@param start fun()
----@param finish fun()
----@param callback vm.runner.callback
-function vm.launchRunner(loc, start, finish, callback)
- local locNode = vm.getNode(loc)
- if not locNode then
- return
- end
-
- local function resumeMaster()
- for i = 1, 10010 do
- if not runnerList or runnerList:getSize() == 0 then
- return
- end
- local deadLock = true
- for runner in runnerList:pairs() do
- local info = runnerInfo[runner]
- local waitingSource = info.target
- if coroutine.status(runner) == 'suspended' then
- local suc, err = coroutine.resume(runner)
- if not suc then
- log.error(debug.traceback(runner, err))
- end
- else
- runnerList:pop(runner)
- deadLock = false
- end
- if not waitingSource or waitingSource ~= info.target then
- deadLock = false
- end
- end
- if runnerList:getSize() == 0 then
- return
- end
- if deadLock then
- resolveDeadLock()
- end
- if i == 10000 then
- local lines = {}
- lines[#lines+1] = 'Dead lock:'
- for runner in runnerList:pairs() do
- local info = runnerInfo[runner]
- lines[#lines+1] = '==============='
- lines[#lines+1] = string.format('Runner `%s` at %d(%s)'
- , info.loc[1]
- , info.loc.start
- , guide.getUri(info.loc)
- )
- lines[#lines+1] = string.format('Waiting `%s` at %d(%s)'
- , info.target[1]
- , info.target.start
- , guide.getUri(info.target)
- )
- end
- local msg = table.concat(lines, '\n')
- log.error(msg)
- end
- end
- end
-
- local function launch()
- start()
- if not loc.ref then
- finish()
- return
- end
- local main = guide.getParentBlock(loc)
- if not main then
- finish()
- return
- end
- local self = setmetatable({
- _loc = loc,
- _casts = {},
- _mark = {},
- _has = {},
- _main = main,
- _uri = guide.getUri(loc),
- _callback = callback,
- }, mt)
-
- self:collect()
-
- self:lookIntoBlock(main, locNode:copy())
-
- locNode:setData('runner', nil)
-
- finish()
- end
-
- local co = coroutine.create(launch)
- locNode:setData('runner', co)
- local info = runnerInfo[co]
- info.loc = loc
-
- if not runnerList then
- runnerList = linked()
- end
- runnerList:pushTail(co)
-
- if not masterRunner then
- masterRunner = coroutine.running()
- resumeMaster()
- masterRunner = nil
- return
- end
-end
-
----@async
----@param source parser.object
-function vm.waitResolveRunner(source)
- local myNode = vm.getNode(source)
- if myNode and myNode.resolved then
- return
- end
-
- local running = coroutine.running()
- if not masterRunner or running == masterRunner then
- return
- end
-
- local info = runnerInfo[running]
-
- local targetLoc
- if source.type == 'getlocal' then
- targetLoc = source.node
- elseif source.type == 'local'
- or source.type == 'self' then
- targetLoc = source
- info.target = info.target or source
- else
- error('Unknown source type: ' .. source.type)
- end
-
- local targetNode = vm.getNode(targetLoc)
- if not targetNode then
- -- Wait for compiling local by `compiler`
- return
- end
-
- waitResolve(info)
-end
-
----@param source parser.object
-function vm.storeWaitingRunner(source)
- local sourceNode = vm.getNode(source)
- if sourceNode and sourceNode.resolved then
- return
- end
-
- local running = coroutine.running()
- local info = runnerInfo[running]
- info.target = source
-end
diff --git a/script/vm/tracer.lua b/script/vm/tracer.lua
new file mode 100644
index 00000000..21a2619f
--- /dev/null
+++ b/script/vm/tracer.lua
@@ -0,0 +1,541 @@
+---@class vm
+local vm = require 'vm.vm'
+local guide = require 'parser.guide'
+local util = require 'utility'
+
+---@class parser.object
+---@field package _tracer? vm.tracer
+---@field package _casts? parser.object[]
+
+---@class vm.tracer
+---@field source parser.object
+---@field assigns parser.object[]
+---@field assignMap table<parser.object, true>
+---@field careMap table<parser.object, true>
+---@field mark table<parser.object, true>
+---@field casts parser.object[]
+---@field nodes table<parser.object, vm.node|false>
+---@field main parser.object
+---@field uri uri
+---@field castIndex integer?
+local mt = {}
+mt.__index = mt
+
+---@return parser.object[]
+function mt:getCasts()
+ local root = guide.getRoot(self.source)
+ if not root._casts then
+ root._casts = {}
+ local docs = root.docs
+ for _, doc in ipairs(docs) do
+ if doc.type == 'doc.cast' and doc.loc then
+ root._casts[#root._casts+1] = doc
+ end
+ end
+ end
+ return root._casts
+end
+
+---@param obj parser.object
+function mt:collectAssign(obj)
+ while true do
+ local block = guide.getParentBlock(obj)
+ if not block then
+ return
+ end
+ obj = block
+ if self.assignMap[obj] then
+ return
+ end
+ if obj == self.main then
+ return
+ end
+ self.assignMap[obj] = true
+ self.assigns[#self.assigns+1] = obj
+ end
+end
+
+---@param obj parser.object
+function mt:collectCare(obj)
+ while true do
+ if self.careMap[obj] then
+ return
+ end
+ if obj == self.main then
+ return
+ end
+ self.careMap[obj] = true
+ obj = obj.parent
+ end
+end
+
+function mt:collectLocal()
+ local startPos = self.source.start
+ local finishPos = 0
+
+ self.assigns[#self.assigns+1] = self.source
+ self.assignMap[self.source] = true
+
+ for _, obj in ipairs(self.source.ref) do
+ if obj.type == 'setlocal' then
+ self.assigns[#self.assigns+1] = obj
+ self.assignMap[obj] = true
+ self:collectCare(obj)
+ if obj.finish > finishPos then
+ finishPos = obj.finish
+ end
+ end
+ if obj.type == 'getlocal' then
+ self:collectCare(obj)
+ if obj.finish > finishPos then
+ finishPos = obj.finish
+ end
+ end
+ end
+
+ local casts = self:getCasts()
+ for _, cast in ipairs(casts) do
+ if cast.loc[1] == self.source[1]
+ and cast.start > startPos
+ and cast.finish < finishPos
+ and guide.getLocal(self.source, self.source[1], cast.start) == self.source then
+ self.casts[#self.casts+1] = cast
+ end
+ end
+end
+
+---@param start integer
+---@param finish integer
+---@return parser.object?
+function mt:getLastAssign(start, finish)
+ local assign
+ for _, obj in ipairs(self.assigns) do
+ if obj.start < start then
+ goto CONTINUE
+ end
+ if (obj.range or obj.start) >= finish then
+ break
+ end
+ local objBlock = guide.getParentBlock(obj)
+ if not objBlock then
+ break
+ end
+ if objBlock.start <= finish
+ and objBlock.finish >= finish then
+ assign = obj
+ end
+ ::CONTINUE::
+ end
+ return assign
+end
+
+---@param pos integer
+function mt:resetCastsIndex(pos)
+ for i = 1, #self.casts do
+ local cast = self.casts[i]
+ if cast.start > pos then
+ self.castIndex = i
+ return
+ end
+ end
+ self.castIndex = nil
+end
+
+---@param pos integer
+---@param node vm.node
+---@return vm.node
+function mt:fastWardCasts(pos, node)
+ if not self.castIndex then
+ return node
+ end
+ for i = self.castIndex, #self.casts do
+ local action = self.casts[i]
+ if action.start > pos then
+ return node
+ end
+ node = node:copy()
+ for _, cast in ipairs(action.casts) do
+ if cast.mode == '+' then
+ if cast.optional then
+ node:addOptional()
+ end
+ if cast.extends then
+ node:merge(vm.compileNode(cast.extends))
+ end
+ elseif cast.mode == '-' then
+ if cast.optional then
+ node:removeOptional()
+ end
+ if cast.extends then
+ node:removeNode(vm.compileNode(cast.extends))
+ end
+ else
+ if cast.extends then
+ node:clear()
+ node:merge(vm.compileNode(cast.extends))
+ end
+ end
+ end
+ end
+ self.castIndex = self.castIndex + 1
+ return node
+end
+
+---@param action parser.object
+---@param topNode vm.node
+---@param outNode? vm.node
+---@return vm.node topNode
+---@return vm.node outNode
+function mt:lookIntoChild(action, topNode, outNode)
+ if not self.careMap[action]
+ or self.mark[action] then
+ return topNode, outNode or topNode
+ end
+ self.mark[action] = true
+ topNode = self:fastWardCasts(action.start, topNode)
+ if action.type == 'getlocal' then
+ if action.node == self.source then
+ self.nodes[action] = topNode
+ if outNode then
+ topNode = topNode:copy():setTruthy()
+ outNode = outNode:copy():setFalsy()
+ end
+ end
+ elseif action.type == 'function' then
+ self:lookIntoBlock(action, action.args.finish, topNode:copy())
+ elseif action.type == 'unary' then
+ if not action[1] then
+ goto RETURN
+ end
+ if action.op.type == 'not' then
+ outNode = outNode or topNode:copy()
+ outNode, topNode = self:lookIntoChild(action[1], topNode, outNode)
+ outNode = outNode:copy()
+ end
+ elseif action.type == 'binary' then
+ if not action[1] or not action[2] then
+ goto RETURN
+ end
+ if action.op.type == 'and' then
+ topNode = self:lookIntoChild(action[1], topNode, topNode:copy())
+ topNode = self:lookIntoChild(action[2], topNode, topNode:copy())
+ elseif action.op.type == 'or' then
+ outNode = outNode or topNode:copy()
+ local topNode1, outNode1 = self:lookIntoChild(action[1], topNode, outNode)
+ local topNode2, outNode2 = self:lookIntoChild(action[2], outNode1, outNode1:copy())
+ topNode = vm.createNode(topNode1, topNode2)
+ outNode = outNode2:copy()
+ elseif action.op.type == '=='
+ or action.op.type == '~=' then
+ local handler, checker
+ for i = 1, 2 do
+ if guide.isLiteral(action[i]) then
+ checker = action[i]
+ handler = action[3-i] -- Copilot tells me use `3-i` instead of `i%2+1`
+ end
+ end
+ if not handler then
+ goto RETURN
+ end
+ if handler.type == 'getlocal'
+ and handler.node == self.source then
+ -- if x == y then
+ topNode = self:lookIntoChild(handler, topNode, outNode)
+ local checkerNode = vm.compileNode(checker)
+ local checkerName = vm.getNodeName(checker)
+ if checkerName then
+ topNode = topNode:copy()
+ if action.op.type == '==' then
+ topNode:narrow(self.uri, checkerName)
+ if outNode then
+ outNode:removeNode(checkerNode)
+ end
+ else
+ topNode:removeNode(checkerNode)
+ if outNode then
+ outNode:narrow(self.uri, checkerName)
+ end
+ end
+ end
+ elseif handler.type == 'call'
+ and checker.type == 'string'
+ and handler.node.special == 'type'
+ and handler.args
+ and handler.args[1]
+ and handler.args[1].type == 'getlocal'
+ and handler.args[1].node == self.source then
+ -- if type(x) == 'string' then
+ self:lookIntoChild(handler, topNode:copy())
+ if action.op.type == '==' then
+ topNode:narrow(self.uri, checker[1])
+ if outNode then
+ outNode:remove(checker[1])
+ end
+ else
+ topNode:remove(checker[1])
+ if outNode then
+ outNode:narrow(self.uri, checker[1])
+ end
+ end
+ elseif handler.type == 'getlocal'
+ and checker.type == 'string' then
+ local nodeValue = vm.getObjectValue(handler.node)
+ if nodeValue
+ and nodeValue.type == 'select'
+ and nodeValue.sindex == 1 then
+ local call = nodeValue.vararg
+ if call
+ and call.type == 'call'
+ and call.node.special == 'type'
+ and call.args
+ and call.args[1]
+ and call.args[1].type == 'getlocal'
+ and call.args[1].node == self.source then
+ -- `local tp = type(x);if tp == 'string' then`
+ if action.op.type == '==' then
+ topNode:narrow(self.uri, checker[1])
+ if outNode then
+ outNode:remove(checker[1])
+ end
+ else
+ topNode:remove(checker[1])
+ if outNode then
+ outNode:narrow(self.uri, checker[1])
+ end
+ end
+ end
+ end
+ end
+ end
+ elseif action.type == 'loop'
+ or action.type == 'in'
+ or action.type == 'repeat'
+ or action.type == 'for'
+ or action.type == 'do' then
+ if action[1] then
+ self:lookIntoBlock(action, action.bstart, topNode:copy())
+ local lastAssign = self:getLastAssign(action.start, action.finish)
+ if lastAssign then
+ self:getNode(lastAssign)
+ end
+ if self.nodes[action] then
+ topNode = self.nodes[action]:copy()
+ end
+ end
+ elseif action.type == 'while' then
+ local blockNode, mainNode
+ if action.filter then
+ blockNode, mainNode = self:lookIntoChild(action.filter, topNode:copy(), topNode:copy())
+ else
+ blockNode = topNode:copy()
+ mainNode = topNode:copy()
+ end
+ if action[1] then
+ self:lookIntoBlock(action, action.bstart, blockNode:copy())
+ local lastAssign = self:getLastAssign(action.start, action.finish)
+ if lastAssign then
+ self:getNode(lastAssign)
+ end
+ if self.nodes[action] then
+ topNode = mainNode:merge(self.nodes[action])
+ end
+ end
+ if action.filter then
+ -- look into filter again
+ guide.eachSource(action.filter, function (src)
+ self.mark[src] = nil
+ end)
+ blockNode, topNode = self:lookIntoChild(action.filter, topNode:copy(), topNode:copy())
+ end
+ elseif action.type == 'if' then
+ local hasElse
+ local mainNode = topNode:copy()
+ local blockNodes = {}
+ for _, subBlock in ipairs(action) do
+ self:resetCastsIndex(subBlock.start)
+ local blockNode = mainNode:copy()
+ if subBlock.filter then
+ blockNode, mainNode = self:lookIntoChild(subBlock.filter, blockNode, mainNode)
+ else
+ hasElse = true
+ mainNode:clear()
+ end
+ local mergedNode
+ if subBlock[1] then
+ self:lookIntoBlock(subBlock, subBlock.bstart, blockNode:copy())
+ local neverReturn = subBlock.hasReturn
+ or subBlock.hasGoTo
+ or subBlock.hasBreak
+ or subBlock.hasError
+ if neverReturn then
+ mergedNode = true
+ else
+ local lastAssign = self:getLastAssign(subBlock.start, subBlock.finish)
+ if lastAssign then
+ self:getNode(lastAssign)
+ end
+ if self.nodes[subBlock] then
+ blockNodes[#blockNodes+1] = self.nodes[subBlock]
+ mergedNode = true
+ end
+ end
+ end
+ if not mergedNode then
+ blockNodes[#blockNodes+1] = blockNode
+ end
+ end
+ if not hasElse and not topNode:hasKnownType() then
+ mainNode:merge(vm.declareGlobal('type', 'unknown'))
+ end
+ for _, blockNode in ipairs(blockNodes) do
+ mainNode:merge(blockNode)
+ end
+ topNode = mainNode
+ elseif action.type == 'call' then
+ if action.node.special == 'assert' and action.args and action.args[1] then
+ topNode = self:lookIntoChild(action.args[1], topNode, topNode:copy())
+ end
+ elseif action.type == 'paren' then
+ topNode, outNode = self:lookIntoChild(action.exp, topNode, outNode)
+ elseif action.type == 'setlocal' then
+ if action.value then
+ self:lookIntoChild(action.value, topNode)
+ end
+ elseif action.type == 'local' then
+ if action.value
+ and action.ref
+ and action.value.type == 'select' then
+ local index = action.value.sindex
+ local call = action.value.vararg
+ if index == 1
+ and call.type == 'call'
+ and call.node
+ and call.node.special == 'type'
+ and call.args then
+ local getLoc = call.args[1]
+ if getLoc
+ and getLoc.type == 'getlocal'
+ and getLoc.node == self.source then
+ for _, ref in ipairs(action.ref) do
+ self:collectCare(ref)
+ end
+ end
+ end
+ end
+ end
+ ::RETURN::
+ guide.eachChild(action, function (src)
+ if self.careMap[src] then
+ self:lookIntoChild(src, topNode)
+ end
+ end)
+ return topNode, outNode or topNode
+end
+
+---@param block parser.object
+---@param start integer
+---@param node vm.node
+function mt:lookIntoBlock(block, start, node)
+ self:resetCastsIndex(start)
+ for _, action in ipairs(block) do
+ if (action.effect or action.start) < start then
+ goto CONTINUE
+ end
+ if self.careMap[action] then
+ node = self:lookIntoChild(action, node)
+ end
+ if action.finish > start and self.assignMap[action] then
+ return
+ end
+ ::CONTINUE::
+ end
+ self.nodes[block] = node
+end
+
+---@param source parser.object
+function mt:calcNode(source)
+ if source.type == 'getlocal' then
+ local lastAssign = self:getLastAssign(0, source.start)
+ if not lastAssign then
+ lastAssign = source.node
+ end
+ self:calcNode(lastAssign)
+ return
+ end
+ if source.type == 'local'
+ or source.type == 'self'
+ or source.type == 'setlocal' then
+ local node = vm.compileNode(source)
+ self.nodes[source] = node
+ local parentBlock = guide.getParentBlock(source)
+ if parentBlock then
+ self:lookIntoBlock(parentBlock, source.finish, node)
+ end
+ return
+ end
+end
+
+---@param source parser.object
+---@return vm.node?
+function mt:getNode(source)
+ local cache = self.nodes[source]
+ if cache ~= nil then
+ return cache or nil
+ end
+ if source == self.main then
+ self.nodes[source] = false
+ return nil
+ end
+ self.nodes[source] = false
+ self:calcNode(source)
+ return self.nodes[source] or nil
+end
+
+---@class vm.node
+---@field package _tracer vm.tracer
+
+---@param source parser.object
+---@return vm.tracer?
+local function createTracer(source)
+ local node = vm.compileNode(source)
+ local tracer = node._tracer
+ if tracer then
+ return tracer
+ end
+ local main = guide.getParentBlock(source)
+ if not main then
+ return nil
+ end
+ tracer = setmetatable({
+ source = source,
+ assigns = {},
+ assignMap = {},
+ careMap = {},
+ mark = {},
+ casts = {},
+ nodes = {},
+ main = main,
+ uri = guide.getUri(source),
+ }, mt)
+ node._tracer = tracer
+
+ tracer:collectLocal()
+
+ return tracer
+end
+
+---@param source parser.object
+---@return vm.node?
+function vm.traceNode(source)
+ local loc
+ if source.type == 'getlocal'
+ or source.type == 'setlocal' then
+ loc = source.node
+ end
+ local tracer = createTracer(loc)
+ if not tracer then
+ return nil
+ end
+ local node = tracer:getNode(source)
+ return node
+end
diff --git a/test/tclient/tests/recursive-runner.lua b/test/tclient/tests/recursive-runner.lua
index ddcdb5d6..e824f23a 100644
--- a/test/tclient/tests/recursive-runner.lua
+++ b/test/tclient/tests/recursive-runner.lua
@@ -174,8 +174,7 @@ end
textDocument = { uri = 'file:///test.lua' },
position = { line = 20, character = 11 },
})
- -- TODO
- --assert(hover1.contents.value:find 'vector3')
+ assert(hover1.contents.value:find 'vector3')
config.set(nil, 'Lua.diagnostics.enable', true)
end)
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 33521a0d..0b69a34c 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -1590,7 +1590,7 @@ AAA = {}
local <?x?> = AAA()
]]
-TEST 'string|integer' [[
+TEST 'string' [[
local <?x?>
x = '1'
x = 1
@@ -1637,7 +1637,7 @@ function A()
end
]]
-TEST 'unknown' [[
+TEST 'string' [[
local x
function A()
@@ -1758,6 +1758,26 @@ x = '1'
x = 1
]]
+TEST 'integer' [[
+local x
+x = true
+do
+ x = 1
+end
+print(<?x?>)
+]]
+
+TEST 'boolean' [[
+local x
+x = true
+function XX()
+ do
+ x = 1
+ end
+end
+print(<?x?>)
+]]
+
TEST 'integer?' [[
---@type integer?
local <?x?>
@@ -1809,6 +1829,17 @@ end
print(<?x?>)
]]
+TEST 'nil' [[
+---@type integer?
+local x
+
+if not x then
+ print(<?x?>)
+end
+
+print(x)
+]]
+
TEST 'integer' [[
---@type integer?
local x
@@ -1840,6 +1871,15 @@ if xxx and x then
end
]]
+TEST 'unknown' [[
+---@type integer?
+local x
+
+if not x and x then
+ print(<?x?>)
+end
+]]
+
TEST 'integer' [[
---@type integer?
local x
@@ -2277,7 +2317,7 @@ local x
print(<?x?>)
]]
-TEST 'unknown?' [[
+TEST 'nil' [[
---@type string?
local x
@@ -2351,7 +2391,7 @@ end
print(<?t?>)
]]
-TEST 'unknown?' [[
+TEST 'nil' [[
---@type integer?
local t
@@ -3160,7 +3200,7 @@ local function f() end
local x, y, <?z?> = 1, 2, f()
]]
-TEST 'function' [[
+TEST 'unknown' [[
local f
print(<?f?>)
@@ -3168,6 +3208,26 @@ print(<?f?>)
function f() end
]]
+TEST 'unknown' [[
+local f
+
+do
+ print(<?f?>)
+end
+
+function f() end
+]]
+
+TEST 'function' [[
+local f
+
+function A()
+ print(<?f?>)
+end
+
+function f() end
+]]
+
TEST 'number' [[
---@type number|nil
local n
@@ -4000,3 +4060,91 @@ local m, v
local <?r?> = m * v
]]
+
+TEST 'A|B' [[
+---@class A
+---@class B
+
+---@type A|B
+local t
+
+if x then
+ ---@cast t A
+else
+ print(<?t?>)
+end
+]]
+
+TEST 'A|B' [[
+---@class A
+---@class B
+
+---@type A|B
+local t
+
+if x then
+ ---@cast t A
+elseif <?t?> then
+end
+]]
+
+TEST 'A|B' [[
+---@class A
+---@class B
+
+---@type A|B
+local t
+
+if x then
+ ---@cast t A
+ print(t)
+elseif <?t?> then
+end
+]]
+
+TEST 'A|B' [[
+---@class A
+---@class B
+
+---@type A|B
+local t
+
+if x then
+ ---@cast t A
+ print(t)
+elseif <?t?> then
+ ---@cast t A
+ print(t)
+end
+]]
+
+TEST 'function' [[
+local function x()
+ print(<?x?>)
+end
+]]
+
+TEST 'number' [[
+---@type number?
+local x
+
+do
+ if not x then
+ return
+ end
+end
+
+print(<?x?>)
+]]
+
+TEST 'number' [[
+---@type number[]
+local xs
+
+---@type fun(x): number?
+local f
+
+for _, <?x?> in ipairs(xs) do
+ x = f(x)
+end
+]]