summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/vm/compiler.lua28
-rw-r--r--script/vm/infer.lua10
-rw-r--r--script/vm/node.lua7
-rw-r--r--script/vm/tracer.lua78
-rw-r--r--test/type_inference/init.lua28
5 files changed, 127 insertions, 24 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index a7ad4464..eefed5a8 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -972,12 +972,25 @@ local function compileLocal(source)
vm.setNode(source, vm.compileNode(source.value))
end
end
- if not hasMarkValue and not hasMarkValue then
- local firstRef = source.ref and source.ref[1]
- if firstRef
- and guide.isSet(firstRef)
- and guide.getBlock(firstRef) == guide.getBlock(source) then
- vm.setNode(source, vm.compileNode(firstRef))
+ if not hasMarkValue
+ and not hasMarkValue
+ and source.ref then
+ local firstSet
+ local myFunction = guide.getParentFunction(source)
+ for _, ref in ipairs(source.ref) do
+ if ref.type == 'setlocal' then
+ firstSet = ref
+ break
+ end
+ if ref.type == 'getlocal' then
+ if guide.getParentFunction(ref) == myFunction then
+ break
+ end
+ end
+ end
+ if firstSet
+ and guide.getBlock(firstSet) == guide.getBlock(source) then
+ vm.setNode(source, vm.compileNode(firstSet))
end
end
-- function x.y(self, ...) --> function x:y(...)
@@ -1164,6 +1177,9 @@ local compilerSwitch = util.switch()
end)
: case 'setlocal'
: call(function (source)
+ if bindDocs(source) then
+ return
+ end
local valueNode = vm.compileNode(source.value)
vm.setNode(source, valueNode)
end)
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index b9dfb29a..99cf622e 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -432,10 +432,14 @@ function mt:view(uri, default)
end
if self.node:isOptional() then
- if max > 1 then
- view = '(' .. view .. ')?'
+ if #array == 0 then
+ view = 'nil'
else
- view = view .. '?'
+ if max > 1 then
+ view = '(' .. view .. ')?'
+ else
+ view = view .. '?'
+ end
end
end
diff --git a/script/vm/node.lua b/script/vm/node.lua
index d0fd5ffb..2e408128 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -188,6 +188,9 @@ end
---@return vm.node
function mt:setFalsy()
+ if self.optional == false then
+ self.optional = nil
+ end
local hasBoolean
for index = #self, 1, -1 do
local c = self[index]
@@ -226,10 +229,6 @@ function mt:setFalsy()
if hasBoolean then
self:merge(vm.declareGlobal('type', 'false'))
end
- if self.optional then
- self.optional = nil
- self:merge(vm.declareGlobal('type', 'nil'))
- end
return self
end
diff --git a/script/vm/tracer.lua b/script/vm/tracer.lua
index c2b7ae39..5b3b1b54 100644
--- a/script/vm/tracer.lua
+++ b/script/vm/tracer.lua
@@ -13,9 +13,11 @@ local util = require 'utility'
---@field assignMap table<parser.object, true>
---@field careMap table<parser.object, true>
---@field mark table<parser.object, true>
+---@field casts parser.object[]
---@field nodes table<parser.object, vm.node|false>
---@field main parser.object
---@field uri uri
+---@field castIndex integer?
local mt = {}
mt.__index = mt
@@ -79,25 +81,27 @@ function mt:collectLocal()
self.assigns[#self.assigns+1] = obj
self.assignMap[obj] = true
self:collectCare(obj)
+ if obj.finish > finishPos then
+ finishPos = obj.finish
+ end
end
if obj.type == 'getlocal' then
self:collectCare(obj)
+ if obj.finish > finishPos then
+ finishPos = obj.finish
+ end
end
end
local casts = self:getCasts()
for _, cast in ipairs(casts) do
if cast.loc[1] == self.source[1]
- and cast.start > startPos
+ and cast.start > startPos
and cast.finish < finishPos
and guide.getLocal(self.source, self.source[1], cast.start) == self.source then
- self.assigns[#self.assigns+1] = cast
+ self.casts[#self.casts+1] = cast
end
end
-
- table.sort(self.assigns, function (a, b)
- return a.start < b.start
- end)
end
---@param start integer
@@ -109,7 +113,7 @@ function mt:getLastAssign(start, finish)
if obj.start < start then
goto CONTINUE
end
- if obj.start >= finish then
+ if (obj.range or obj.start) >= finish then
break
end
local objBlock = guide.getParentBlock(obj)
@@ -125,6 +129,58 @@ function mt:getLastAssign(start, finish)
return assign
end
+---@param pos integer
+function mt:resetCastsIndex(pos)
+ for i = 1, #self.casts do
+ local cast = self.casts[i]
+ if cast.start > pos then
+ self.castIndex = i
+ return
+ end
+ end
+ self.castIndex = nil
+end
+
+---@param pos integer
+---@param node vm.node
+---@return vm.node
+function mt:fastWardCasts(pos, node)
+ if not self.castIndex then
+ return node
+ end
+ for i = self.castIndex, #self.casts do
+ local action = self.casts[i]
+ if action.start > pos then
+ return node
+ end
+ node = node:copy()
+ for _, cast in ipairs(action.casts) do
+ if cast.mode == '+' then
+ if cast.optional then
+ node:addOptional()
+ end
+ if cast.extends then
+ node:merge(vm.compileNode(cast.extends))
+ end
+ elseif cast.mode == '-' then
+ if cast.optional then
+ node:removeOptional()
+ end
+ if cast.extends then
+ node:removeNode(vm.compileNode(cast.extends))
+ end
+ else
+ if cast.extends then
+ node:clear()
+ node:merge(vm.compileNode(cast.extends))
+ end
+ end
+ end
+ end
+ self.castIndex = self.castIndex + 1
+ return node
+end
+
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
@@ -136,6 +192,7 @@ function mt:lookIntoChild(action, topNode, outNode)
return topNode, outNode or topNode
end
self.mark[action] = true
+ topNode = self:fastWardCasts(action.start, topNode)
if action.type == 'getlocal' then
if action.node == self.source then
self.nodes[action] = topNode
@@ -306,13 +363,18 @@ function mt:lookIntoChild(action, topNode, outNode)
or subBlock.hasBreak
or subBlock.hasError
if not neverReturn then
+ local ok
local lastAssign = self:getLastAssign(subBlock.start, subBlock.finish)
if lastAssign then
local node = self:getNode(lastAssign)
if node then
blockNodes[#blockNodes+1] = node
+ ok = true
end
end
+ if not ok then
+ blockNodes[#blockNodes+1] = blockNode
+ end
end
end
if not hasElse and not topNode:hasKnownType() then
@@ -367,6 +429,7 @@ end
---@param start integer
---@param node vm.node
function mt:lookIntoBlock(block, start, node)
+ self:resetCastsIndex(start)
for _, action in ipairs(block) do
if action.start < start then
goto CONTINUE
@@ -436,6 +499,7 @@ local function createTracer(source)
assignMap = {},
careMap = {},
mark = {},
+ casts = {},
nodes = {},
main = main,
uri = guide.getUri(source),
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 8ae65e48..0eb43ed3 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -1637,7 +1637,7 @@ function A()
end
]]
-TEST 'unknown' [[
+TEST 'string' [[
local x
function A()
@@ -2317,7 +2317,7 @@ local x
print(<?x?>)
]]
-TEST 'unknown?' [[
+TEST 'nil' [[
---@type string?
local x
@@ -2391,7 +2391,7 @@ end
print(<?t?>)
]]
-TEST 'unknown?' [[
+TEST 'nil' [[
---@type integer?
local t
@@ -3200,7 +3200,7 @@ local function f() end
local x, y, <?z?> = 1, 2, f()
]]
-TEST 'function' [[
+TEST 'unknown' [[
local f
print(<?f?>)
@@ -3208,6 +3208,26 @@ print(<?f?>)
function f() end
]]
+TEST 'unknown' [[
+local f
+
+do
+ print(<?f?>)
+end
+
+function f() end
+]]
+
+TEST 'function' [[
+local f
+
+function A()
+ print(<?f?>)
+end
+
+function f() end
+]]
+
TEST 'number' [[
---@type number|nil
local n