summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/vm/compiler.lua86
-rw-r--r--script/vm/init.lua2
-rw-r--r--script/vm/tracer.lua144
-rw-r--r--test/type_inference/init.lua2
4 files changed, 159 insertions, 75 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 10bd5221..a7ad4464 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -973,14 +973,11 @@ local function compileLocal(source)
end
end
if not hasMarkValue and not hasMarkValue then
- if source.ref then
- for _, ref in ipairs(source.ref) do
- if ref.type == 'setlocal'
- and ref.value
- and ref.value.type == 'function' then
- vm.setNode(source, vm.compileNode(ref.value))
- end
- end
+ local firstRef = source.ref and source.ref[1]
+ if firstRef
+ and guide.isSet(firstRef)
+ and guide.getBlock(firstRef) == guide.getBlock(source) then
+ vm.setNode(source, vm.compileNode(firstRef))
end
end
-- function x.y(self, ...) --> function x:y(...)
@@ -1163,73 +1160,12 @@ local compilerSwitch = util.switch()
---@async
---@param source parser.object
: call(function (source)
- vm.launchRunner(source, function ()
- local myNode = vm.getNode(source)
- ---@cast myNode -?
- myNode:setData('resolving', true)
-
- if source.ref then
- for _, ref in ipairs(source.ref) do
- if ref.type == 'getlocal'
- or ref.type == 'setlocal' then
- vm.setNode(ref, myNode, true)
- end
- end
- end
- compileLocal(source)
-
- myNode.resolved = true
- end, function ()
- local myNode = vm.getNode(source)
- ---@cast myNode -?
- myNode:setData('resolving', nil)
- local hasMark = vm.getNode(source):getData 'hasDefined'
- if source.ref and not hasMark then
- local parentFunc = guide.getParentFunction(source)
- for _, ref in ipairs(source.ref) do
- if ref.type == 'setlocal'
- and guide.getParentFunction(ref) == parentFunc then
- local refNode = vm.getNode(ref)
- if refNode then
- vm.setNode(source, refNode)
- end
- end
- end
- end
- end, function (src, node)
- if src.type == 'setlocal' then
- if src.bindDocs then
- for _, doc in ipairs(src.bindDocs) do
- if doc.type == 'doc.type' then
- vm.setNode(src, vm.compileNode(doc), true)
- return vm.getNode(src)
- end
- end
- end
- if src.value then
- if src.value.type == 'table' then
- vm.setNode(src, vm.createNode(src.value), true)
- vm.setNode(src, node:copy():asTable())
- else
- vm.setNode(src, vm.compileNode(src.value), true)
- end
- else
- vm.setNode(src, node, true)
- end
- return vm.getNode(src)
- elseif src.type == 'getlocal' then
- if bindAs(src) then
- return
- end
- vm.setNode(src, node, true)
- node.resolved = true
- matchCall(src)
- end
- end)
+ compileLocal(source)
end)
: case 'setlocal'
: call(function (source)
- vm.compileNode(source.node)
+ local valueNode = vm.compileNode(source.value)
+ vm.setNode(source, valueNode)
end)
: case 'getlocal'
---@async
@@ -1237,7 +1173,11 @@ local compilerSwitch = util.switch()
if bindAs(source) then
return
end
- vm.compileNode(source.node)
+ local node = vm.traceNode(source)
+ if not node then
+ return
+ end
+ vm.setNode(source, node, true)
end)
: case 'setfield'
: case 'setmethod'
diff --git a/script/vm/init.lua b/script/vm/init.lua
index 7b69a7eb..9c8ebe55 100644
--- a/script/vm/init.lua
+++ b/script/vm/init.lua
@@ -11,7 +11,7 @@ require 'vm.field'
require 'vm.doc'
require 'vm.type'
require 'vm.library'
-require 'vm.runner'
+require 'vm.tracer'
require 'vm.infer'
require 'vm.generic'
require 'vm.sign'
diff --git a/script/vm/tracer.lua b/script/vm/tracer.lua
new file mode 100644
index 00000000..ba83fd5b
--- /dev/null
+++ b/script/vm/tracer.lua
@@ -0,0 +1,144 @@
+---@class vm
+local vm = require 'vm.vm'
+local guide = require 'parser.guide'
+
+---@class parser.object
+---@field package _tracer? vm.tracer
+---@field package _casts? parser.object[]
+
+---@class vm.tracer
+---@field source parser.object
+---@field assigns parser.object[]
+---@field nodes table<parser.object, vm.node>
+---@field main parser.object
+---@field uri uri
+local mt = {}
+mt.__index = mt
+
+---@return parser.object[]
+function mt:getCasts()
+ local root = guide.getRoot(self.source)
+ if not root._casts then
+ root._casts = {}
+ local docs = root.docs
+ for _, doc in ipairs(docs) do
+ if doc.type == 'doc.cast' and doc.loc then
+ root._casts[#root._casts+1] = doc
+ end
+ end
+ end
+ return root._casts
+end
+
+---@param obj parser.object
+---@param mark table
+function mt:collectBlock(obj, mark)
+ while true do
+ if mark[obj] then
+ return
+ end
+ mark[obj] = true
+ self.assigns[#self.assigns+1] = obj
+ if obj == self.main then
+ return
+ end
+ obj = obj.parent
+ end
+end
+
+function mt:collectLocal()
+ local startPos = self.source.start
+ local finishPos = 0
+
+ local mark = {}
+
+ for _, obj in ipairs(self.source.ref) do
+ if obj.type == 'setlocal' then
+ self.assigns[#self.assigns+1] = obj
+ self:collectBlock(obj, mark)
+ end
+ end
+
+ local casts = self:getCasts()
+ for _, cast in ipairs(casts) do
+ if cast.loc[1] == self.source[1]
+ and cast.start > startPos
+ and cast.finish < finishPos
+ and guide.getLocal(self.source, self.source[1], cast.start) == self.source then
+ self.assigns[#self.assigns+1] = cast
+ end
+ end
+
+ table.sort(self.assigns, function (a, b)
+ return a.start < b.start
+ end)
+end
+
+---@param source parser.object
+---@return parser.object?
+function mt:getLastAssign(source)
+ local assign = self.source
+ for _, obj in ipairs(self.assigns) do
+ if obj.start > source.start then
+ break
+ end
+ assign = obj
+ end
+ return assign
+end
+
+---@param source parser.object
+---@return vm.node?
+function mt:getNode(source)
+ if self.nodes[source] then
+ return self.nodes[source]
+ end
+ local lastAssign = self:getLastAssign(source)
+ if not lastAssign then
+ return nil
+ end
+ if guide.isSet(lastAssign) then
+ local lastNode = vm.compileNode(lastAssign)
+ return lastNode
+ end
+end
+
+---@param source parser.object
+---@return vm.tracer?
+local function createTracer(source)
+ if source._tracer then
+ return source._tracer
+ end
+ local main = guide.getParentBlock(source)
+ if not main then
+ return nil
+ end
+ local tracer = setmetatable({
+ source = source,
+ assigns = {},
+ nodes = {},
+ main = main,
+ uri = guide.getUri(source),
+ }, mt)
+ source._tracer = tracer
+
+ tracer:collectLocal()
+
+ return tracer
+end
+
+---@param source parser.object
+---@return vm.node?
+function vm.traceNode(source)
+ local loc
+ if source.type == 'getlocal'
+ or source.type == 'setlocal' then
+ loc = source.node
+ end
+ local tracer = createTracer(loc)
+ if not tracer then
+ return nil
+ end
+ local node = tracer:getNode(source)
+ return node
+end
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 33521a0d..02b3da03 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -1590,7 +1590,7 @@ AAA = {}
local <?x?> = AAA()
]]
-TEST 'string|integer' [[
+TEST 'string' [[
local <?x?>
x = '1'
x = 1