summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/vm/compiler.lua6
-rw-r--r--script/vm/def.lua12
-rw-r--r--script/vm/field.lua2
-rw-r--r--script/vm/global.lua2
-rw-r--r--script/vm/local-id.lua50
-rw-r--r--script/vm/ref.lua15
6 files changed, 62 insertions, 25 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index cd2b602d..797fa901 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -59,9 +59,9 @@ local searchFieldSwitch = util.switch()
: call(function (suri, node, key, ref, pushResult)
local fields
if key then
- fields = vm.getLocalSources(node, key)
+ fields = vm.getLocalSourcesSets(node, key)
else
- fields = vm.getLocalFields(node)
+ fields = vm.getLocalFields(node, false)
end
if fields then
for _, src in ipairs(fields) do
@@ -586,7 +586,7 @@ local function bindDocs(source)
end
local function compileByLocalID(source)
- local sources = vm.getLocalSources(source)
+ local sources = vm.getLocalSourcesSets(source)
if not sources then
return
end
diff --git a/script/vm/def.lua b/script/vm/def.lua
index 83e92686..ad343ae6 100644
--- a/script/vm/def.lua
+++ b/script/vm/def.lua
@@ -115,12 +115,10 @@ local searchFieldSwitch = util.switch()
end)
: case 'local'
: call(function (suri, obj, key, pushResult)
- local sources = vm.getLocalSources(obj, key)
+ local sources = vm.getLocalSourcesSets(obj, key)
if sources then
for _, src in ipairs(sources) do
- if guide.isSet(src) then
- pushResult(src)
- end
+ pushResult(src)
end
end
end)
@@ -194,14 +192,12 @@ end
---@param source parser.object
---@param pushResult fun(src: parser.object)
local function searchByLocalID(source, pushResult)
- local idSources = vm.getLocalSources(source)
+ local idSources = vm.getLocalSourcesSets(source)
if not idSources then
return
end
for _, src in ipairs(idSources) do
- if guide.isSet(src) then
- pushResult(src)
- end
+ pushResult(src)
end
end
diff --git a/script/vm/field.lua b/script/vm/field.lua
index 5de838be..b92c3a7b 100644
--- a/script/vm/field.lua
+++ b/script/vm/field.lua
@@ -16,7 +16,7 @@ local searchByNodeSwitch = util.switch()
end)
local function searchByLocalID(source, pushResult)
- local fields = vm.getLocalFields(source)
+ local fields = vm.getLocalFields(source, true)
if fields then
for _, field in ipairs(fields) do
pushResult(field)
diff --git a/script/vm/global.lua b/script/vm/global.lua
index a54ab552..81791d14 100644
--- a/script/vm/global.lua
+++ b/script/vm/global.lua
@@ -446,7 +446,7 @@ local function compileSelf(source)
if not node then
return
end
- local fields = vm.getLocalFields(source)
+ local fields = vm.getLocalFields(source, false)
if not fields then
return
end
diff --git a/script/vm/local-id.lua b/script/vm/local-id.lua
index 80c68769..bb12a927 100644
--- a/script/vm/local-id.lua
+++ b/script/vm/local-id.lua
@@ -5,7 +5,7 @@ local vm = require 'vm.vm'
---@class parser.object
---@field _localID string
----@field _localIDs table<string, parser.object[]>
+---@field _localIDs table<string, { sets: parser.object[], gets: parser.object[] }[]>
local compileLocalID, getLocal
@@ -114,10 +114,19 @@ end
function vm.insertLocalID(id, source)
local root = guide.getRoot(source)
if not root._localIDs then
- root._localIDs = util.multiTable(2)
+ root._localIDs = util.multiTable(2, function ()
+ return {
+ sets = {},
+ gets = {},
+ }
+ end)
end
local sources = root._localIDs[id]
- sources[#sources+1] = source
+ if guide.isSet(source) then
+ sources.sets[#sources.sets+1] = source
+ else
+ sources.gets[#sources.gets+1] = source
+ end
end
function compileLocalID(source)
@@ -154,7 +163,7 @@ end
---@param source parser.object
---@param key? string
---@return parser.object[]?
-function vm.getLocalSources(source, key)
+function vm.getLocalSourcesSets(source, key)
local id = vm.getLocalID(source)
if not id then
return nil
@@ -169,12 +178,34 @@ function vm.getLocalSources(source, key)
end
id = id .. vm.ID_SPLITE .. key
end
- return root._localIDs[id]
+ return root._localIDs[id].sets
end
---@param source parser.object
+---@param key? string
+---@return parser.object[]?
+function vm.getLocalSourcesGets(source, key)
+ local id = vm.getLocalID(source)
+ if not id then
+ return nil
+ end
+ local root = guide.getRoot(source)
+ if not root._localIDs then
+ return nil
+ end
+ if key then
+ if type(key) ~= 'string' then
+ return nil
+ end
+ id = id .. vm.ID_SPLITE .. key
+ end
+ return root._localIDs[id].gets
+end
+
+---@param source parser.object
+---@param includeGets boolean
---@return parser.object[]
-function vm.getLocalFields(source)
+function vm.getLocalFields(source, includeGets)
local id = vm.getLocalID(source)
if not id then
return nil
@@ -192,9 +223,14 @@ function vm.getLocalFields(source)
and lid:sub(#id + 1, #id + 1) == vm.ID_SPLITE
-- only one field
and not lid:find(vm.ID_SPLITE, #id + 2) then
- for _, src in ipairs(sources) do
+ for _, src in ipairs(sources.sets) do
fields[#fields+1] = src
end
+ if includeGets then
+ for _, src in ipairs(sources.gets) do
+ fields[#fields+1] = src
+ end
+ end
end
end
local cost = os.clock() - clock
diff --git a/script/vm/ref.lua b/script/vm/ref.lua
index fbb9d015..031a2e69 100644
--- a/script/vm/ref.lua
+++ b/script/vm/ref.lua
@@ -240,12 +240,17 @@ end
---@param source parser.object
---@param pushResult fun(src: parser.object)
local function searchByLocalID(source, pushResult)
- local idSources = vm.getLocalSources(source)
- if not idSources then
- return
+ local sourceSets = vm.getLocalSourcesSets(source)
+ if sourceSets then
+ for _, src in ipairs(sourceSets) do
+ pushResult(src)
+ end
end
- for _, src in ipairs(idSources) do
- pushResult(src)
+ local sourceGets = vm.getLocalSourcesGets(source)
+ if sourceGets then
+ for _, src in ipairs(sourceGets) do
+ pushResult(src)
+ end
end
end