From 98589ddbeba3cb7e9341e79bb674bad9e8b957c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=80=E8=90=8C=E5=B0=8F=E6=B1=90?= Date: Mon, 12 Dec 2022 21:09:43 +0800 Subject: stash --- script/vm/compiler.lua | 86 ++++---------------------- script/vm/init.lua | 2 +- script/vm/tracer.lua | 144 +++++++++++++++++++++++++++++++++++++++++++ test/type_inference/init.lua | 2 +- 4 files changed, 159 insertions(+), 75 deletions(-) create mode 100644 script/vm/tracer.lua diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 10bd5221..a7ad4464 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -973,14 +973,11 @@ local function compileLocal(source) end end if not hasMarkValue and not hasMarkValue then - if source.ref then - for _, ref in ipairs(source.ref) do - if ref.type == 'setlocal' - and ref.value - and ref.value.type == 'function' then - vm.setNode(source, vm.compileNode(ref.value)) - end - end + local firstRef = source.ref and source.ref[1] + if firstRef + and guide.isSet(firstRef) + and guide.getBlock(firstRef) == guide.getBlock(source) then + vm.setNode(source, vm.compileNode(firstRef)) end end -- function x.y(self, ...) --> function x:y(...) @@ -1163,73 +1160,12 @@ local compilerSwitch = util.switch() ---@async ---@param source parser.object : call(function (source) - vm.launchRunner(source, function () - local myNode = vm.getNode(source) - ---@cast myNode -? - myNode:setData('resolving', true) - - if source.ref then - for _, ref in ipairs(source.ref) do - if ref.type == 'getlocal' - or ref.type == 'setlocal' then - vm.setNode(ref, myNode, true) - end - end - end - compileLocal(source) - - myNode.resolved = true - end, function () - local myNode = vm.getNode(source) - ---@cast myNode -? - myNode:setData('resolving', nil) - local hasMark = vm.getNode(source):getData 'hasDefined' - if source.ref and not hasMark then - local parentFunc = guide.getParentFunction(source) - for _, ref in ipairs(source.ref) do - if ref.type == 'setlocal' - and guide.getParentFunction(ref) == parentFunc then - local refNode = vm.getNode(ref) - if refNode then - vm.setNode(source, refNode) - end - end - end - end - end, function (src, node) - if src.type == 'setlocal' then - if src.bindDocs then - for _, doc in ipairs(src.bindDocs) do - if doc.type == 'doc.type' then - vm.setNode(src, vm.compileNode(doc), true) - return vm.getNode(src) - end - end - end - if src.value then - if src.value.type == 'table' then - vm.setNode(src, vm.createNode(src.value), true) - vm.setNode(src, node:copy():asTable()) - else - vm.setNode(src, vm.compileNode(src.value), true) - end - else - vm.setNode(src, node, true) - end - return vm.getNode(src) - elseif src.type == 'getlocal' then - if bindAs(src) then - return - end - vm.setNode(src, node, true) - node.resolved = true - matchCall(src) - end - end) + compileLocal(source) end) : case 'setlocal' : call(function (source) - vm.compileNode(source.node) + local valueNode = vm.compileNode(source.value) + vm.setNode(source, valueNode) end) : case 'getlocal' ---@async @@ -1237,7 +1173,11 @@ local compilerSwitch = util.switch() if bindAs(source) then return end - vm.compileNode(source.node) + local node = vm.traceNode(source) + if not node then + return + end + vm.setNode(source, node, true) end) : case 'setfield' : case 'setmethod' diff --git a/script/vm/init.lua b/script/vm/init.lua index 7b69a7eb..9c8ebe55 100644 --- a/script/vm/init.lua +++ b/script/vm/init.lua @@ -11,7 +11,7 @@ require 'vm.field' require 'vm.doc' require 'vm.type' require 'vm.library' -require 'vm.runner' +require 'vm.tracer' require 'vm.infer' require 'vm.generic' require 'vm.sign' diff --git a/script/vm/tracer.lua b/script/vm/tracer.lua new file mode 100644 index 00000000..ba83fd5b --- /dev/null +++ b/script/vm/tracer.lua @@ -0,0 +1,144 @@ +---@class vm +local vm = require 'vm.vm' +local guide = require 'parser.guide' + +---@class parser.object +---@field package _tracer? vm.tracer +---@field package _casts? parser.object[] + +---@class vm.tracer +---@field source parser.object +---@field assigns parser.object[] +---@field nodes table +---@field main parser.object +---@field uri uri +local mt = {} +mt.__index = mt + +---@return parser.object[] +function mt:getCasts() + local root = guide.getRoot(self.source) + if not root._casts then + root._casts = {} + local docs = root.docs + for _, doc in ipairs(docs) do + if doc.type == 'doc.cast' and doc.loc then + root._casts[#root._casts+1] = doc + end + end + end + return root._casts +end + +---@param obj parser.object +---@param mark table +function mt:collectBlock(obj, mark) + while true do + if mark[obj] then + return + end + mark[obj] = true + self.assigns[#self.assigns+1] = obj + if obj == self.main then + return + end + obj = obj.parent + end +end + +function mt:collectLocal() + local startPos = self.source.start + local finishPos = 0 + + local mark = {} + + for _, obj in ipairs(self.source.ref) do + if obj.type == 'setlocal' then + self.assigns[#self.assigns+1] = obj + self:collectBlock(obj, mark) + end + end + + local casts = self:getCasts() + for _, cast in ipairs(casts) do + if cast.loc[1] == self.source[1] + and cast.start > startPos + and cast.finish < finishPos + and guide.getLocal(self.source, self.source[1], cast.start) == self.source then + self.assigns[#self.assigns+1] = cast + end + end + + table.sort(self.assigns, function (a, b) + return a.start < b.start + end) +end + +---@param source parser.object +---@return parser.object? +function mt:getLastAssign(source) + local assign = self.source + for _, obj in ipairs(self.assigns) do + if obj.start > source.start then + break + end + assign = obj + end + return assign +end + +---@param source parser.object +---@return vm.node? +function mt:getNode(source) + if self.nodes[source] then + return self.nodes[source] + end + local lastAssign = self:getLastAssign(source) + if not lastAssign then + return nil + end + if guide.isSet(lastAssign) then + local lastNode = vm.compileNode(lastAssign) + return lastNode + end +end + +---@param source parser.object +---@return vm.tracer? +local function createTracer(source) + if source._tracer then + return source._tracer + end + local main = guide.getParentBlock(source) + if not main then + return nil + end + local tracer = setmetatable({ + source = source, + assigns = {}, + nodes = {}, + main = main, + uri = guide.getUri(source), + }, mt) + source._tracer = tracer + + tracer:collectLocal() + + return tracer +end + +---@param source parser.object +---@return vm.node? +function vm.traceNode(source) + local loc + if source.type == 'getlocal' + or source.type == 'setlocal' then + loc = source.node + end + local tracer = createTracer(loc) + if not tracer then + return nil + end + local node = tracer:getNode(source) + return node +end diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 33521a0d..02b3da03 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -1590,7 +1590,7 @@ AAA = {} local = AAA() ]] -TEST 'string|integer' [[ +TEST 'string' [[ local x = '1' x = 1 -- cgit v1.2.3