diff options
Diffstat (limited to 'script/vm')
-rw-r--r-- | script/vm/compiler.lua | 36 | ||||
-rw-r--r-- | script/vm/doc.lua | 27 | ||||
-rw-r--r-- | script/vm/tracer.lua | 70 | ||||
-rw-r--r-- | script/vm/variable-id.lua | 20 |
4 files changed, 99 insertions, 54 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 060a2173..dbcb403d 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1252,29 +1252,29 @@ local compilerSwitch = util.switch() end end else - if guide.isAssign(source) then - ---@cast key string - vm.compileByParentNode(source.node, key, function (src) - if src.value then - if bindDocs(src) then - vm.setNode(source, vm.compileNode(src)) - elseif src.value.type ~= 'nil' then - vm.setNode(source, vm.compileNode(src.value)) - local node = vm.getNode(src) - if node then - vm.setNode(source, node) - end - end - else - vm.setNode(source, vm.compileNode(src)) - end - end) - else + if guide.isGet(source) then local node = vm.traceNode(source) if node then vm.setNode(source, node) + return end end + ---@cast key string + vm.compileByParentNode(source.node, key, function (src) + if src.value then + if bindDocs(src) then + vm.setNode(source, vm.compileNode(src)) + elseif src.value.type ~= 'nil' then + vm.setNode(source, vm.compileNode(src.value)) + local node = vm.getNode(src) + if node then + vm.setNode(source, node) + end + end + else + vm.setNode(source, vm.compileNode(src)) + end + end) end end) : case 'setglobal' diff --git a/script/vm/doc.lua b/script/vm/doc.lua index b292bc3c..f0f7c54c 100644 --- a/script/vm/doc.lua +++ b/script/vm/doc.lua @@ -4,6 +4,9 @@ local guide = require 'parser.guide' local vm = require 'vm.vm' local config = require 'config' +---@class parser.object +---@field package _castTargetHead parser.object | vm.global | false + ---获取class与alias ---@param suri uri ---@param name? string @@ -414,3 +417,27 @@ function vm.isDiagDisabledAt(uri, position, name, err) end return count > 0 end + +---@param doc parser.object +---@return (parser.object | vm.global)? +function vm.getCastTargetHead(doc) + if doc._castTargetHead ~= nil then + return doc._castTargetHead or nil + end + local name = doc.name[1]:match '^[^%.]+' + if not name then + doc._castTargetHead = false + return nil + end + local loc = guide.getLocal(doc, name, doc.start) + if loc then + doc._castTargetHead = loc + return loc + end + local global = vm.getGlobal('variable', name) + if global then + doc._castTargetHead = global + return global + end + return nil +end diff --git a/script/vm/tracer.lua b/script/vm/tracer.lua index 33af9d0e..65a329ca 100644 --- a/script/vm/tracer.lua +++ b/script/vm/tracer.lua @@ -7,7 +7,10 @@ local util = require 'utility' ---@field package _tracer? vm.tracer ---@field package _casts? parser.object[] +---@alias tracer.mode 'local' | 'global' + ---@class vm.tracer +---@field mode tracer.mode ---@field name string ---@field source parser.object ---@field assigns parser.object[] @@ -94,21 +97,24 @@ function mt:collectLocal() self.assigns[#self.assigns+1] = self.source self.assignMap[self.source] = true - for _, obj in ipairs(self.source.ref) do - if obj.type == 'setlocal' then - self.assigns[#self.assigns+1] = obj - self.assignMap[obj] = true - self:collectCare(obj) - if obj.finish > finishPos then - finishPos = obj.finish - end + local varInfo = vm.getVariableInfoByName(self.source, self.name) + + assert(varInfo) + + for _, set in ipairs(varInfo.sets) do + self.assigns[#self.assigns+1] = set + self.assignMap[set] = true + self:collectCare(set) + if set.finish > finishPos then + finishPos = set.finish end - if obj.type == 'getlocal' then - self:collectCare(obj) - self.getMap[obj] = true - if obj.finish > finishPos then - finishPos = obj.finish - end + end + + for _, get in ipairs(varInfo.gets) do + self:collectCare(get) + self.getMap[get] = true + if get.finish > finishPos then + finishPos = get.finish end end @@ -117,7 +123,7 @@ function mt:collectLocal() if cast.name[1] == self.name and cast.start > startPos and cast.finish < finishPos - and guide.getLocal(self.source, self.name, cast.start) == self.source then + and vm.getCastTargetHead(cast) == self.source then self.casts[#self.casts+1] = cast end end @@ -128,9 +134,6 @@ function mt:collectLocal() end function mt:collectGlobal() - local startPos = 0 - local finishPos = 0 - self.assigns[#self.assigns+1] = self.source self.assignMap[self.source] = true @@ -142,23 +145,20 @@ function mt:collectGlobal() self.assigns[#self.assigns+1] = set self.assignMap[set] = true self:collectCare(set) - if set.finish > finishPos then - finishPos = set.finish - end end for _, get in ipairs(link.gets) do self:collectCare(get) self.getMap[get] = true - if get.finish > finishPos then - finishPos = get.finish - end end local casts = self:getCasts() for _, cast in ipairs(casts) do if cast.name[1] == self.name then - self.casts[#self.casts+1] = cast + local castTarget = vm.getCastTargetHead(cast) + if castTarget and castTarget.type == 'global' then + self.casts[#self.casts+1] = cast + end end end @@ -780,10 +780,11 @@ end ---@class vm.node ---@field package _tracer vm.tracer +---@param mode tracer.mode ---@param source parser.object ---@param name string ---@return vm.tracer? -local function createTracer(source, name) +local function createTracer(mode, source, name) local node = vm.compileNode(source) local tracer = node._tracer if tracer then @@ -795,6 +796,7 @@ local function createTracer(source, name) end tracer = setmetatable({ source = source, + mode = mode, name = name, assigns = {}, assignMap = {}, @@ -808,8 +810,7 @@ local function createTracer(source, name) }, mt) node._tracer = tracer - if source.type == 'local' - or source.type == 'self' then + if tracer.mode == 'local' then tracer:collectLocal() else tracer:collectGlobal() @@ -821,17 +822,13 @@ end ---@param source parser.object ---@return vm.node? function vm.traceNode(source) - local base, name - if source.type == 'getlocal' - or source.type == 'setlocal' then - base = source.node - ---@type string - name = source[1] - elseif vm.getGlobalNode(source) then + local mode, base, name + if vm.getGlobalNode(source) then base = vm.getGlobalBase(source) if not base then return nil end + mode = 'global' name = base.global:getCodeName() else base = vm.getVariableHead(source) @@ -842,8 +839,9 @@ function vm.traceNode(source) if not name then return nil end + mode = 'local' end - local tracer = createTracer(base, name) + local tracer = createTracer(mode, base, name) if not tracer then return nil end diff --git a/script/vm/variable-id.lua b/script/vm/variable-id.lua index a2aece65..4a441958 100644 --- a/script/vm/variable-id.lua +++ b/script/vm/variable-id.lua @@ -202,6 +202,26 @@ function vm.getVariableInfo(source, key) end ---@param source parser.object +---@param name string +---@return vm.variable? +function vm.getVariableInfoByName(source, name) + local id = vm.getVariableID(source) + if not id then + return nil + end + local root = guide.getRoot(source) + if not root._variableIDs then + return nil + end + local headPos = name:find('.', 1, true) + if not headPos then + return root._variableIDs[id] + end + local vid = id .. name:sub(headPos):gsub('%.', vm.ID_SPLITE) + return root._variableIDs[vid] +end + +---@param source parser.object ---@param key? string ---@return parser.object[]? function vm.getVariableSets(source, key) |