summaryrefslogtreecommitdiff
path: root/script
diff options
context:
space:
mode:
Diffstat (limited to 'script')
-rw-r--r--script/core/hover/table.lua2
-rw-r--r--script/core/semantic-tokens.lua7
-rw-r--r--script/utility.lua74
-rw-r--r--script/vm/compiler.lua32
-rw-r--r--script/vm/def.lua33
-rw-r--r--script/vm/field.lua45
-rw-r--r--script/vm/global-manager.lua8
-rw-r--r--script/vm/infer.lua7
-rw-r--r--script/vm/init.lua1
-rw-r--r--script/vm/local-id.lua17
-rw-r--r--script/vm/ref.lua24
11 files changed, 138 insertions, 112 deletions
diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua
index e2c439af..4ad5e552 100644
--- a/script/core/hover/table.lua
+++ b/script/core/hover/table.lua
@@ -154,7 +154,7 @@ return function (source)
return 'table'
end
- local fields = vm.getRefs(source, '*')
+ local fields = vm.getFields(source)
local keys = getKeyMap(fields)
local optMap = getOptionalMap(fields)
diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua
index ec19af1a..7f58014b 100644
--- a/script/core/semantic-tokens.lua
+++ b/script/core/semantic-tokens.lua
@@ -668,7 +668,6 @@ local Care = util.switch()
type = define.TokenTypes.keyword,
}
end)
- : getMap()
local function buildTokens(uri, results)
local tokens = {}
@@ -805,11 +804,7 @@ return function (uri, start, finish)
local results = {}
guide.eachSourceBetween(state.ast, start, finish, function (source) ---@async
- local method = Care[source.type]
- if not method then
- return
- end
- method(source, options, results)
+ Care(source.type, source, options, results)
await.delay()
end)
diff --git a/script/utility.lua b/script/utility.lua
index 0e4df627..e282f12f 100644
--- a/script/utility.lua
+++ b/script/utility.lua
@@ -673,29 +673,59 @@ function m.arrayToHash(l)
return t
end
-function m.switch()
- local map = {}
- local cachedCases = {}
- local obj = {
- case = function (self, name)
- cachedCases[#cachedCases+1] = name
- return self
- end,
- call = function (self, callback)
- for i = 1, #cachedCases do
- local name = cachedCases[i]
- cachedCases[i] = nil
- if map[name] then
- error('Repeated fields:' .. tostring(name))
- end
- map[name] = callback
- end
- return self
- end,
- getMap = function (self)
- return map
+---@class switch
+---@field cachedCases string[]
+---@field map table<string, function>
+local switchMT = {}
+switchMT.__index = switchMT
+
+---@param name string
+---@return switch
+function switchMT:case(name)
+ self.cachedCases[#self.cachedCases+1] = name
+ return self
+end
+
+---@param callback fun(...):...
+---@return switch
+function switchMT:call(callback)
+ for i = 1, #self.cachedCases do
+ local name = self.cachedCases[i]
+ self.cachedCases[i] = nil
+ if self.map[name] then
+ error('Repeated fields:' .. tostring(name))
end
- }
+ self.map[name] = callback
+ end
+ return self
+end
+
+function switchMT:getMap()
+ return self.map
+end
+
+---@param name string
+---@return boolean
+function switchMT:has(name)
+ return self.map[name] ~= nil
+end
+
+---@param name string
+---@return ...
+function switchMT:__call(name, ...)
+ local callback = self.map[name]
+ if not callback then
+ return
+ end
+ return callback(...)
+end
+
+---@return switch
+function m.switch()
+ local obj = setmetatable({
+ map = {},
+ cachedCases = {},
+ }, switchMT)
return obj
end
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index d6caa71f..d1f8c601 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -15,7 +15,7 @@ local genericMgr = require 'vm.generic'
---@class vm.node.compiler
local m = {}
-local searchFieldMap = util.switch()
+local searchFieldSwitch = util.switch()
: case 'table'
: call(function (node, key, pushResult)
for _, field in ipairs(node) do
@@ -88,7 +88,6 @@ local searchFieldMap = util.switch()
end
end
end)
- : getMap()
function m.getClassFields(node, key, pushResult)
@@ -112,16 +111,14 @@ function m.getClassFields(node, key, pushResult)
-- check local field and global field
if set.bindSources then
for _, src in ipairs(set.bindSources) do
- if searchFieldMap[src.type] then
- searchFieldMap[src.type](src, key, function (field)
- if guide.isSet(field) then
- hasFounded = true
- pushResult(field)
- end
- end)
- end
+ searchFieldSwitch(src.type, src, key, function (field)
+ if guide.isSet(field) then
+ hasFounded = true
+ pushResult(field)
+ end
+ end)
if src._globalNode then
- searchFieldMap['global'](src._globalNode, key, function (field)
+ searchFieldSwitch('global', src._globalNode, key, function (field)
hasFounded = true
pushResult(field)
end)
@@ -352,10 +349,7 @@ function m.compileByParentNode(source, key, pushResult)
return
end
for node in nodeMgr.eachNode(parentNode) do
- local f = searchFieldMap[node.type]
- if f then
- f(node, key, pushResult)
- end
+ searchFieldSwitch(node.type, node, key, pushResult)
end
end
@@ -466,7 +460,7 @@ local function setCallArgNode(source, call, callNode, fixIndex)
end
end
-local compilerMap = util.switch()
+local compilerSwitch = util.switch()
: case 'nil'
: case 'boolean'
: case 'integer'
@@ -1165,14 +1159,10 @@ local compilerMap = util.switch()
end
end
end)
- : getMap()
---@param source parser.object
local function compileByNode(source)
- local compiler = compilerMap[source.type]
- if compiler then
- compiler(source)
- end
+ compilerSwitch(source.type, source)
end
---@param source vm.node
diff --git a/script/vm/def.lua b/script/vm/def.lua
index 22c52be1..98d41fbd 100644
--- a/script/vm/def.lua
+++ b/script/vm/def.lua
@@ -7,7 +7,7 @@ local localID = require 'vm.local-id'
local globalMgr = require 'vm.global-manager'
local nodeMgr = require 'vm.node'
-local simpleMap
+local simpleSwitch
local function searchGetLocal(source, node, pushResult)
local key = guide.getKeyName(source)
@@ -21,7 +21,7 @@ local function searchGetLocal(source, node, pushResult)
end
end
-simpleMap = util.switch()
+simpleSwitch = util.switch()
: case 'local'
: call(function (source, pushResult)
pushResult(source)
@@ -42,13 +42,13 @@ simpleMap = util.switch()
: case 'getlocal'
: case 'setlocal'
: call(function (source, pushResult)
- simpleMap['local'](source.node, pushResult)
+ simpleSwitch('local', source.node, pushResult)
end)
: case 'field'
: call(function (source, pushResult)
local parent = source.parent
if parent.type ~= 'tablefield' then
- simpleMap[parent.type](parent, pushResult)
+ simpleSwitch(parent.type, parent, pushResult)
end
end)
: case 'setfield'
@@ -74,9 +74,8 @@ simpleMap = util.switch()
pushResult(source.node)
end
end)
- : getMap()
-local searchFieldMap = util.switch()
+local searchFieldSwitch = util.switch()
: case 'table'
: call(function (node, key, pushResult)
for _, field in ipairs(node) do
@@ -126,10 +125,9 @@ local searchFieldMap = util.switch()
end
end
end)
- : getMap()
local searchByParentNode
-local nodeMap = util.switch()
+local nodeSwitch = util.switch()
: case 'field'
: case 'method'
: call(function (source, pushResult)
@@ -148,9 +146,7 @@ local nodeMap = util.switch()
end
local key = guide.getKeyName(source)
for pn in nodeMgr.eachNode(parentNode) do
- if searchFieldMap[pn.type] then
- searchFieldMap[pn.type](pn, key, pushResult)
- end
+ searchFieldSwitch(pn.type, pn, key, pushResult)
end
end)
: case 'doc.see.field'
@@ -160,20 +156,14 @@ local nodeMap = util.switch()
return
end
for pn in nodeMgr.eachNode(parentNode) do
- if searchFieldMap[pn.type] then
- searchFieldMap[pn.type](pn, source[1], pushResult)
- end
+ searchFieldSwitch(pn.type, pn, source[1], pushResult)
end
end)
- : getMap()
---@param source parser.object
---@param pushResult fun(src: parser.object)
local function searchBySimple(source, pushResult)
- local simple = simpleMap[source.type]
- if simple then
- simple(source, pushResult)
- end
+ simpleSwitch(source.type, source, pushResult)
end
---@param source parser.object
@@ -193,10 +183,7 @@ end
---@param source parser.object
---@param pushResult fun(src: parser.object)
function searchByParentNode(source, pushResult)
- local node = nodeMap[source.type]
- if node then
- node(source, pushResult)
- end
+ nodeSwitch(source.type, source, pushResult)
end
local function searchByNode(source, pushResult)
diff --git a/script/vm/field.lua b/script/vm/field.lua
new file mode 100644
index 00000000..92448bb3
--- /dev/null
+++ b/script/vm/field.lua
@@ -0,0 +1,45 @@
+---@class vm
+local vm = require 'vm.vm'
+local util = require 'utility'
+local compiler = require 'vm.compiler'
+local guide = require 'parser.guide'
+local localID = require 'vm.local-id'
+local globalMgr = require 'vm.global-manager'
+local nodeMgr = require 'vm.node'
+
+local searchNodeSwitch = util.switch()
+ : case 'table'
+ : call(function (node, pushResult)
+ for _, field in ipairs(node) do
+ if field.type == 'tablefield'
+ or field.type == 'tableindex' then
+ pushResult(field)
+ end
+ end
+ end)
+
+local function searchByNode(source, pushResult)
+ local node = compiler.compileNode(source)
+ if not node then
+ return
+ end
+ searchNodeSwitch(node.type, node, pushResult)
+end
+
+---@param source parser.object
+---@return parser.object[]
+function vm.getFields(source)
+ local results = {}
+ local mark = {}
+
+ local function pushResult(src)
+ if not mark[src] then
+ mark[src] = true
+ results[#results+1] = src
+ end
+ end
+
+ searchByNode(source, pushResult)
+
+ return results
+end
diff --git a/script/vm/global-manager.lua b/script/vm/global-manager.lua
index f78faf7d..99d1b697 100644
--- a/script/vm/global-manager.lua
+++ b/script/vm/global-manager.lua
@@ -16,7 +16,7 @@ m.globalSubs = util.multiTable(2)
m.ID_SPLITE = '\x1F'
-local compilerGlobalMap = util.switch()
+local compilerGlobalSwitch = util.switch()
: case 'local'
: call(function (source)
if source.special ~= '_G' then
@@ -195,7 +195,6 @@ local compilerGlobalMap = util.switch()
class:addGet(uri, source)
source._globalNode = class
end)
- : getMap()
---@alias vm.global.cate '"variable"' | '"type"'
@@ -231,10 +230,7 @@ function m.compileObject(source)
return
end
source._globalNode = false
- local compiler = compilerGlobalMap[source.type]
- if compiler then
- compiler(source)
- end
+ compilerGlobalSwitch(source.type, source)
end
---@param source parser.object
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index 5ac7d73b..6457696a 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -17,7 +17,7 @@ local inferSorted = {
['nil'] = 100,
}
-local viewNodeMap = util.switch()
+local viewNodeSwitch = util.switch()
: case 'nil'
: case 'boolean'
: case 'string'
@@ -112,14 +112,11 @@ local viewNodeMap = util.switch()
end
return ('fun(%s)%s'):format(argView, regView)
end)
- : getMap()
---@param node vm.node
---@return string?
local function viewNode(node, options)
- if viewNodeMap[node.type] then
- return viewNodeMap[node.type](node, options)
- end
+ return viewNodeSwitch(node.type, node, options)
end
local function eraseAlias(node, viewMap, options)
diff --git a/script/vm/init.lua b/script/vm/init.lua
index ca6b8606..b5d37136 100644
--- a/script/vm/init.lua
+++ b/script/vm/init.lua
@@ -2,6 +2,7 @@ local vm = require 'vm.vm'
require 'vm.manager'
require 'vm.def'
require 'vm.ref'
+require 'vm.field'
require 'vm.getDocs'
require 'vm.getLibrary'
require 'vm.getLinks'
diff --git a/script/vm/local-id.lua b/script/vm/local-id.lua
index 3b266916..4c9da197 100644
--- a/script/vm/local-id.lua
+++ b/script/vm/local-id.lua
@@ -10,7 +10,7 @@ local m = {}
m.ID_SPLITE = '\x1F'
-local compileMap = util.switch()
+local compileSwitch = util.switch()
: case 'local'
: call(function (source)
source._localID = ('%d'):format(source.start)
@@ -69,9 +69,8 @@ local compileMap = util.switch()
m.compileLocalID(source.next)
end
end)
- : getMap()
-local leftMap = util.switch()
+local leftSwitch = util.switch()
: case 'field'
: case 'method'
: call(function (source)
@@ -94,16 +93,11 @@ local leftMap = util.switch()
: call(function (source)
return source
end)
- : getMap()
---@param source parser.object
---@return parser.object?
function m.getLocal(source)
- local getLeft = leftMap[source.type]
- if getLeft then
- return getLeft(source)
- end
- return nil
+ return leftSwitch(source.type, source)
end
function m.compileLocalID(source)
@@ -111,11 +105,10 @@ function m.compileLocalID(source)
return
end
source._localID = false
- local compiler = compileMap[source.type]
- if not compiler then
+ if not compileSwitch:has(source.type) then
return
end
- compiler(source)
+ compileSwitch(source.type, source)
local root = guide.getRoot(source)
if not root._localIDs then
root._localIDs = util.multiTable(2)
diff --git a/script/vm/ref.lua b/script/vm/ref.lua
index 4765dc8d..b679ae4a 100644
--- a/script/vm/ref.lua
+++ b/script/vm/ref.lua
@@ -8,7 +8,7 @@ local globalMgr = require 'vm.global-manager'
local nodeMgr = require 'vm.node'
local files = require 'files'
-local simpleMap
+local simpleSwitch
local function searchGetLocal(source, node, pushResult)
local key = guide.getKeyName(source)
@@ -21,7 +21,7 @@ local function searchGetLocal(source, node, pushResult)
end
end
-simpleMap = util.switch()
+simpleSwitch = util.switch()
: case 'local'
: call(function (source, pushResult)
if source.ref then
@@ -36,13 +36,13 @@ simpleMap = util.switch()
: case 'getlocal'
: case 'setlocal'
: call(function (source, pushResult)
- simpleMap['local'](source.node, pushResult)
+ simpleSwitch('local', source.node, pushResult)
end)
: case 'field'
: call(function (source, pushResult)
local parent = source.parent
if parent.type ~= 'tablefield' then
- simpleMap[parent.type](parent, pushResult)
+ simpleSwitch(parent.type, parent, pushResult)
end
end)
: case 'setfield'
@@ -65,7 +65,7 @@ simpleMap = util.switch()
: case 'goto'
: call(function (source, pushResult)
if source.node then
- simpleMap['label'](source.node, pushResult)
+ simpleSwitch('label', source.node, pushResult)
pushResult(source.node)
end
end)
@@ -78,7 +78,6 @@ simpleMap = util.switch()
end
end
end)
- : getMap()
local function searchField(source, pushResult)
local key = guide.getKeyName(source)
@@ -132,7 +131,7 @@ local function searchField(source, pushResult)
end
local searchByParentNode
-local nodeMap = util.switch()
+local nodeSwitch = util.switch()
: case 'field'
: case 'method'
: call(function (source, pushResult)
@@ -162,15 +161,11 @@ local nodeMap = util.switch()
: call(function (source, pushResult)
searchField(source, pushResult)
end)
- : getMap()
---@param source parser.object
---@param pushResult fun(src: parser.object)
local function searchBySimple(source, pushResult)
- local simple = simpleMap[source.type]
- if simple then
- simple(source, pushResult)
- end
+ simpleSwitch(source.type, source, pushResult)
end
---@param source parser.object
@@ -190,10 +185,7 @@ end
---@param source parser.object
---@param pushResult fun(src: parser.object)
function searchByParentNode(source, pushResult)
- local node = nodeMap[source.type]
- if node then
- node(source, pushResult)
- end
+ nodeSwitch(source.type, source, pushResult)
end
local function searchByNode(source, pushResult)