summaryrefslogtreecommitdiff
path: root/script
diff options
context:
space:
mode:
Diffstat (limited to 'script')
-rw-r--r--script/vm/compiler.lua36
-rw-r--r--script/vm/doc.lua27
-rw-r--r--script/vm/tracer.lua70
-rw-r--r--script/vm/variable-id.lua20
4 files changed, 99 insertions, 54 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 060a2173..dbcb403d 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -1252,29 +1252,29 @@ local compilerSwitch = util.switch()
end
end
else
- if guide.isAssign(source) then
- ---@cast key string
- vm.compileByParentNode(source.node, key, function (src)
- if src.value then
- if bindDocs(src) then
- vm.setNode(source, vm.compileNode(src))
- elseif src.value.type ~= 'nil' then
- vm.setNode(source, vm.compileNode(src.value))
- local node = vm.getNode(src)
- if node then
- vm.setNode(source, node)
- end
- end
- else
- vm.setNode(source, vm.compileNode(src))
- end
- end)
- else
+ if guide.isGet(source) then
local node = vm.traceNode(source)
if node then
vm.setNode(source, node)
+ return
end
end
+ ---@cast key string
+ vm.compileByParentNode(source.node, key, function (src)
+ if src.value then
+ if bindDocs(src) then
+ vm.setNode(source, vm.compileNode(src))
+ elseif src.value.type ~= 'nil' then
+ vm.setNode(source, vm.compileNode(src.value))
+ local node = vm.getNode(src)
+ if node then
+ vm.setNode(source, node)
+ end
+ end
+ else
+ vm.setNode(source, vm.compileNode(src))
+ end
+ end)
end
end)
: case 'setglobal'
diff --git a/script/vm/doc.lua b/script/vm/doc.lua
index b292bc3c..f0f7c54c 100644
--- a/script/vm/doc.lua
+++ b/script/vm/doc.lua
@@ -4,6 +4,9 @@ local guide = require 'parser.guide'
local vm = require 'vm.vm'
local config = require 'config'
+---@class parser.object
+---@field package _castTargetHead parser.object | vm.global | false
+
---获取class与alias
---@param suri uri
---@param name? string
@@ -414,3 +417,27 @@ function vm.isDiagDisabledAt(uri, position, name, err)
end
return count > 0
end
+
+---@param doc parser.object
+---@return (parser.object | vm.global)?
+function vm.getCastTargetHead(doc)
+ if doc._castTargetHead ~= nil then
+ return doc._castTargetHead or nil
+ end
+ local name = doc.name[1]:match '^[^%.]+'
+ if not name then
+ doc._castTargetHead = false
+ return nil
+ end
+ local loc = guide.getLocal(doc, name, doc.start)
+ if loc then
+ doc._castTargetHead = loc
+ return loc
+ end
+ local global = vm.getGlobal('variable', name)
+ if global then
+ doc._castTargetHead = global
+ return global
+ end
+ return nil
+end
diff --git a/script/vm/tracer.lua b/script/vm/tracer.lua
index 33af9d0e..65a329ca 100644
--- a/script/vm/tracer.lua
+++ b/script/vm/tracer.lua
@@ -7,7 +7,10 @@ local util = require 'utility'
---@field package _tracer? vm.tracer
---@field package _casts? parser.object[]
+---@alias tracer.mode 'local' | 'global'
+
---@class vm.tracer
+---@field mode tracer.mode
---@field name string
---@field source parser.object
---@field assigns parser.object[]
@@ -94,21 +97,24 @@ function mt:collectLocal()
self.assigns[#self.assigns+1] = self.source
self.assignMap[self.source] = true
- for _, obj in ipairs(self.source.ref) do
- if obj.type == 'setlocal' then
- self.assigns[#self.assigns+1] = obj
- self.assignMap[obj] = true
- self:collectCare(obj)
- if obj.finish > finishPos then
- finishPos = obj.finish
- end
+ local varInfo = vm.getVariableInfoByName(self.source, self.name)
+
+ assert(varInfo)
+
+ for _, set in ipairs(varInfo.sets) do
+ self.assigns[#self.assigns+1] = set
+ self.assignMap[set] = true
+ self:collectCare(set)
+ if set.finish > finishPos then
+ finishPos = set.finish
end
- if obj.type == 'getlocal' then
- self:collectCare(obj)
- self.getMap[obj] = true
- if obj.finish > finishPos then
- finishPos = obj.finish
- end
+ end
+
+ for _, get in ipairs(varInfo.gets) do
+ self:collectCare(get)
+ self.getMap[get] = true
+ if get.finish > finishPos then
+ finishPos = get.finish
end
end
@@ -117,7 +123,7 @@ function mt:collectLocal()
if cast.name[1] == self.name
and cast.start > startPos
and cast.finish < finishPos
- and guide.getLocal(self.source, self.name, cast.start) == self.source then
+ and vm.getCastTargetHead(cast) == self.source then
self.casts[#self.casts+1] = cast
end
end
@@ -128,9 +134,6 @@ function mt:collectLocal()
end
function mt:collectGlobal()
- local startPos = 0
- local finishPos = 0
-
self.assigns[#self.assigns+1] = self.source
self.assignMap[self.source] = true
@@ -142,23 +145,20 @@ function mt:collectGlobal()
self.assigns[#self.assigns+1] = set
self.assignMap[set] = true
self:collectCare(set)
- if set.finish > finishPos then
- finishPos = set.finish
- end
end
for _, get in ipairs(link.gets) do
self:collectCare(get)
self.getMap[get] = true
- if get.finish > finishPos then
- finishPos = get.finish
- end
end
local casts = self:getCasts()
for _, cast in ipairs(casts) do
if cast.name[1] == self.name then
- self.casts[#self.casts+1] = cast
+ local castTarget = vm.getCastTargetHead(cast)
+ if castTarget and castTarget.type == 'global' then
+ self.casts[#self.casts+1] = cast
+ end
end
end
@@ -780,10 +780,11 @@ end
---@class vm.node
---@field package _tracer vm.tracer
+---@param mode tracer.mode
---@param source parser.object
---@param name string
---@return vm.tracer?
-local function createTracer(source, name)
+local function createTracer(mode, source, name)
local node = vm.compileNode(source)
local tracer = node._tracer
if tracer then
@@ -795,6 +796,7 @@ local function createTracer(source, name)
end
tracer = setmetatable({
source = source,
+ mode = mode,
name = name,
assigns = {},
assignMap = {},
@@ -808,8 +810,7 @@ local function createTracer(source, name)
}, mt)
node._tracer = tracer
- if source.type == 'local'
- or source.type == 'self' then
+ if tracer.mode == 'local' then
tracer:collectLocal()
else
tracer:collectGlobal()
@@ -821,17 +822,13 @@ end
---@param source parser.object
---@return vm.node?
function vm.traceNode(source)
- local base, name
- if source.type == 'getlocal'
- or source.type == 'setlocal' then
- base = source.node
- ---@type string
- name = source[1]
- elseif vm.getGlobalNode(source) then
+ local mode, base, name
+ if vm.getGlobalNode(source) then
base = vm.getGlobalBase(source)
if not base then
return nil
end
+ mode = 'global'
name = base.global:getCodeName()
else
base = vm.getVariableHead(source)
@@ -842,8 +839,9 @@ function vm.traceNode(source)
if not name then
return nil
end
+ mode = 'local'
end
- local tracer = createTracer(base, name)
+ local tracer = createTracer(mode, base, name)
if not tracer then
return nil
end
diff --git a/script/vm/variable-id.lua b/script/vm/variable-id.lua
index a2aece65..4a441958 100644
--- a/script/vm/variable-id.lua
+++ b/script/vm/variable-id.lua
@@ -202,6 +202,26 @@ function vm.getVariableInfo(source, key)
end
---@param source parser.object
+---@param name string
+---@return vm.variable?
+function vm.getVariableInfoByName(source, name)
+ local id = vm.getVariableID(source)
+ if not id then
+ return nil
+ end
+ local root = guide.getRoot(source)
+ if not root._variableIDs then
+ return nil
+ end
+ local headPos = name:find('.', 1, true)
+ if not headPos then
+ return root._variableIDs[id]
+ end
+ local vid = id .. name:sub(headPos):gsub('%.', vm.ID_SPLITE)
+ return root._variableIDs[vid]
+end
+
+---@param source parser.object
---@param key? string
---@return parser.object[]?
function vm.getVariableSets(source, key)