diff options
Diffstat (limited to 'script')
-rw-r--r-- | script/core/hover/table.lua | 2 | ||||
-rw-r--r-- | script/core/semantic-tokens.lua | 7 | ||||
-rw-r--r-- | script/utility.lua | 74 | ||||
-rw-r--r-- | script/vm/compiler.lua | 32 | ||||
-rw-r--r-- | script/vm/def.lua | 33 | ||||
-rw-r--r-- | script/vm/field.lua | 45 | ||||
-rw-r--r-- | script/vm/global-manager.lua | 8 | ||||
-rw-r--r-- | script/vm/infer.lua | 7 | ||||
-rw-r--r-- | script/vm/init.lua | 1 | ||||
-rw-r--r-- | script/vm/local-id.lua | 17 | ||||
-rw-r--r-- | script/vm/ref.lua | 24 |
11 files changed, 138 insertions, 112 deletions
diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua index e2c439af..4ad5e552 100644 --- a/script/core/hover/table.lua +++ b/script/core/hover/table.lua @@ -154,7 +154,7 @@ return function (source) return 'table' end - local fields = vm.getRefs(source, '*') + local fields = vm.getFields(source) local keys = getKeyMap(fields) local optMap = getOptionalMap(fields) diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua index ec19af1a..7f58014b 100644 --- a/script/core/semantic-tokens.lua +++ b/script/core/semantic-tokens.lua @@ -668,7 +668,6 @@ local Care = util.switch() type = define.TokenTypes.keyword, } end) - : getMap() local function buildTokens(uri, results) local tokens = {} @@ -805,11 +804,7 @@ return function (uri, start, finish) local results = {} guide.eachSourceBetween(state.ast, start, finish, function (source) ---@async - local method = Care[source.type] - if not method then - return - end - method(source, options, results) + Care(source.type, source, options, results) await.delay() end) diff --git a/script/utility.lua b/script/utility.lua index 0e4df627..e282f12f 100644 --- a/script/utility.lua +++ b/script/utility.lua @@ -673,29 +673,59 @@ function m.arrayToHash(l) return t end -function m.switch() - local map = {} - local cachedCases = {} - local obj = { - case = function (self, name) - cachedCases[#cachedCases+1] = name - return self - end, - call = function (self, callback) - for i = 1, #cachedCases do - local name = cachedCases[i] - cachedCases[i] = nil - if map[name] then - error('Repeated fields:' .. tostring(name)) - end - map[name] = callback - end - return self - end, - getMap = function (self) - return map +---@class switch +---@field cachedCases string[] +---@field map table<string, function> +local switchMT = {} +switchMT.__index = switchMT + +---@param name string +---@return switch +function switchMT:case(name) + self.cachedCases[#self.cachedCases+1] = name + return self +end + +---@param callback fun(...):... +---@return switch +function switchMT:call(callback) + for i = 1, #self.cachedCases do + local name = self.cachedCases[i] + self.cachedCases[i] = nil + if self.map[name] then + error('Repeated fields:' .. tostring(name)) end - } + self.map[name] = callback + end + return self +end + +function switchMT:getMap() + return self.map +end + +---@param name string +---@return boolean +function switchMT:has(name) + return self.map[name] ~= nil +end + +---@param name string +---@return ... +function switchMT:__call(name, ...) + local callback = self.map[name] + if not callback then + return + end + return callback(...) +end + +---@return switch +function m.switch() + local obj = setmetatable({ + map = {}, + cachedCases = {}, + }, switchMT) return obj end diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index d6caa71f..d1f8c601 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -15,7 +15,7 @@ local genericMgr = require 'vm.generic' ---@class vm.node.compiler local m = {} -local searchFieldMap = util.switch() +local searchFieldSwitch = util.switch() : case 'table' : call(function (node, key, pushResult) for _, field in ipairs(node) do @@ -88,7 +88,6 @@ local searchFieldMap = util.switch() end end end) - : getMap() function m.getClassFields(node, key, pushResult) @@ -112,16 +111,14 @@ function m.getClassFields(node, key, pushResult) -- check local field and global field if set.bindSources then for _, src in ipairs(set.bindSources) do - if searchFieldMap[src.type] then - searchFieldMap[src.type](src, key, function (field) - if guide.isSet(field) then - hasFounded = true - pushResult(field) - end - end) - end + searchFieldSwitch(src.type, src, key, function (field) + if guide.isSet(field) then + hasFounded = true + pushResult(field) + end + end) if src._globalNode then - searchFieldMap['global'](src._globalNode, key, function (field) + searchFieldSwitch('global', src._globalNode, key, function (field) hasFounded = true pushResult(field) end) @@ -352,10 +349,7 @@ function m.compileByParentNode(source, key, pushResult) return end for node in nodeMgr.eachNode(parentNode) do - local f = searchFieldMap[node.type] - if f then - f(node, key, pushResult) - end + searchFieldSwitch(node.type, node, key, pushResult) end end @@ -466,7 +460,7 @@ local function setCallArgNode(source, call, callNode, fixIndex) end end -local compilerMap = util.switch() +local compilerSwitch = util.switch() : case 'nil' : case 'boolean' : case 'integer' @@ -1165,14 +1159,10 @@ local compilerMap = util.switch() end end end) - : getMap() ---@param source parser.object local function compileByNode(source) - local compiler = compilerMap[source.type] - if compiler then - compiler(source) - end + compilerSwitch(source.type, source) end ---@param source vm.node diff --git a/script/vm/def.lua b/script/vm/def.lua index 22c52be1..98d41fbd 100644 --- a/script/vm/def.lua +++ b/script/vm/def.lua @@ -7,7 +7,7 @@ local localID = require 'vm.local-id' local globalMgr = require 'vm.global-manager' local nodeMgr = require 'vm.node' -local simpleMap +local simpleSwitch local function searchGetLocal(source, node, pushResult) local key = guide.getKeyName(source) @@ -21,7 +21,7 @@ local function searchGetLocal(source, node, pushResult) end end -simpleMap = util.switch() +simpleSwitch = util.switch() : case 'local' : call(function (source, pushResult) pushResult(source) @@ -42,13 +42,13 @@ simpleMap = util.switch() : case 'getlocal' : case 'setlocal' : call(function (source, pushResult) - simpleMap['local'](source.node, pushResult) + simpleSwitch('local', source.node, pushResult) end) : case 'field' : call(function (source, pushResult) local parent = source.parent if parent.type ~= 'tablefield' then - simpleMap[parent.type](parent, pushResult) + simpleSwitch(parent.type, parent, pushResult) end end) : case 'setfield' @@ -74,9 +74,8 @@ simpleMap = util.switch() pushResult(source.node) end end) - : getMap() -local searchFieldMap = util.switch() +local searchFieldSwitch = util.switch() : case 'table' : call(function (node, key, pushResult) for _, field in ipairs(node) do @@ -126,10 +125,9 @@ local searchFieldMap = util.switch() end end end) - : getMap() local searchByParentNode -local nodeMap = util.switch() +local nodeSwitch = util.switch() : case 'field' : case 'method' : call(function (source, pushResult) @@ -148,9 +146,7 @@ local nodeMap = util.switch() end local key = guide.getKeyName(source) for pn in nodeMgr.eachNode(parentNode) do - if searchFieldMap[pn.type] then - searchFieldMap[pn.type](pn, key, pushResult) - end + searchFieldSwitch(pn.type, pn, key, pushResult) end end) : case 'doc.see.field' @@ -160,20 +156,14 @@ local nodeMap = util.switch() return end for pn in nodeMgr.eachNode(parentNode) do - if searchFieldMap[pn.type] then - searchFieldMap[pn.type](pn, source[1], pushResult) - end + searchFieldSwitch(pn.type, pn, source[1], pushResult) end end) - : getMap() ---@param source parser.object ---@param pushResult fun(src: parser.object) local function searchBySimple(source, pushResult) - local simple = simpleMap[source.type] - if simple then - simple(source, pushResult) - end + simpleSwitch(source.type, source, pushResult) end ---@param source parser.object @@ -193,10 +183,7 @@ end ---@param source parser.object ---@param pushResult fun(src: parser.object) function searchByParentNode(source, pushResult) - local node = nodeMap[source.type] - if node then - node(source, pushResult) - end + nodeSwitch(source.type, source, pushResult) end local function searchByNode(source, pushResult) diff --git a/script/vm/field.lua b/script/vm/field.lua new file mode 100644 index 00000000..92448bb3 --- /dev/null +++ b/script/vm/field.lua @@ -0,0 +1,45 @@ +---@class vm +local vm = require 'vm.vm' +local util = require 'utility' +local compiler = require 'vm.compiler' +local guide = require 'parser.guide' +local localID = require 'vm.local-id' +local globalMgr = require 'vm.global-manager' +local nodeMgr = require 'vm.node' + +local searchNodeSwitch = util.switch() + : case 'table' + : call(function (node, pushResult) + for _, field in ipairs(node) do + if field.type == 'tablefield' + or field.type == 'tableindex' then + pushResult(field) + end + end + end) + +local function searchByNode(source, pushResult) + local node = compiler.compileNode(source) + if not node then + return + end + searchNodeSwitch(node.type, node, pushResult) +end + +---@param source parser.object +---@return parser.object[] +function vm.getFields(source) + local results = {} + local mark = {} + + local function pushResult(src) + if not mark[src] then + mark[src] = true + results[#results+1] = src + end + end + + searchByNode(source, pushResult) + + return results +end diff --git a/script/vm/global-manager.lua b/script/vm/global-manager.lua index f78faf7d..99d1b697 100644 --- a/script/vm/global-manager.lua +++ b/script/vm/global-manager.lua @@ -16,7 +16,7 @@ m.globalSubs = util.multiTable(2) m.ID_SPLITE = '\x1F' -local compilerGlobalMap = util.switch() +local compilerGlobalSwitch = util.switch() : case 'local' : call(function (source) if source.special ~= '_G' then @@ -195,7 +195,6 @@ local compilerGlobalMap = util.switch() class:addGet(uri, source) source._globalNode = class end) - : getMap() ---@alias vm.global.cate '"variable"' | '"type"' @@ -231,10 +230,7 @@ function m.compileObject(source) return end source._globalNode = false - local compiler = compilerGlobalMap[source.type] - if compiler then - compiler(source) - end + compilerGlobalSwitch(source.type, source) end ---@param source parser.object diff --git a/script/vm/infer.lua b/script/vm/infer.lua index 5ac7d73b..6457696a 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -17,7 +17,7 @@ local inferSorted = { ['nil'] = 100, } -local viewNodeMap = util.switch() +local viewNodeSwitch = util.switch() : case 'nil' : case 'boolean' : case 'string' @@ -112,14 +112,11 @@ local viewNodeMap = util.switch() end return ('fun(%s)%s'):format(argView, regView) end) - : getMap() ---@param node vm.node ---@return string? local function viewNode(node, options) - if viewNodeMap[node.type] then - return viewNodeMap[node.type](node, options) - end + return viewNodeSwitch(node.type, node, options) end local function eraseAlias(node, viewMap, options) diff --git a/script/vm/init.lua b/script/vm/init.lua index ca6b8606..b5d37136 100644 --- a/script/vm/init.lua +++ b/script/vm/init.lua @@ -2,6 +2,7 @@ local vm = require 'vm.vm' require 'vm.manager' require 'vm.def' require 'vm.ref' +require 'vm.field' require 'vm.getDocs' require 'vm.getLibrary' require 'vm.getLinks' diff --git a/script/vm/local-id.lua b/script/vm/local-id.lua index 3b266916..4c9da197 100644 --- a/script/vm/local-id.lua +++ b/script/vm/local-id.lua @@ -10,7 +10,7 @@ local m = {} m.ID_SPLITE = '\x1F' -local compileMap = util.switch() +local compileSwitch = util.switch() : case 'local' : call(function (source) source._localID = ('%d'):format(source.start) @@ -69,9 +69,8 @@ local compileMap = util.switch() m.compileLocalID(source.next) end end) - : getMap() -local leftMap = util.switch() +local leftSwitch = util.switch() : case 'field' : case 'method' : call(function (source) @@ -94,16 +93,11 @@ local leftMap = util.switch() : call(function (source) return source end) - : getMap() ---@param source parser.object ---@return parser.object? function m.getLocal(source) - local getLeft = leftMap[source.type] - if getLeft then - return getLeft(source) - end - return nil + return leftSwitch(source.type, source) end function m.compileLocalID(source) @@ -111,11 +105,10 @@ function m.compileLocalID(source) return end source._localID = false - local compiler = compileMap[source.type] - if not compiler then + if not compileSwitch:has(source.type) then return end - compiler(source) + compileSwitch(source.type, source) local root = guide.getRoot(source) if not root._localIDs then root._localIDs = util.multiTable(2) diff --git a/script/vm/ref.lua b/script/vm/ref.lua index 4765dc8d..b679ae4a 100644 --- a/script/vm/ref.lua +++ b/script/vm/ref.lua @@ -8,7 +8,7 @@ local globalMgr = require 'vm.global-manager' local nodeMgr = require 'vm.node' local files = require 'files' -local simpleMap +local simpleSwitch local function searchGetLocal(source, node, pushResult) local key = guide.getKeyName(source) @@ -21,7 +21,7 @@ local function searchGetLocal(source, node, pushResult) end end -simpleMap = util.switch() +simpleSwitch = util.switch() : case 'local' : call(function (source, pushResult) if source.ref then @@ -36,13 +36,13 @@ simpleMap = util.switch() : case 'getlocal' : case 'setlocal' : call(function (source, pushResult) - simpleMap['local'](source.node, pushResult) + simpleSwitch('local', source.node, pushResult) end) : case 'field' : call(function (source, pushResult) local parent = source.parent if parent.type ~= 'tablefield' then - simpleMap[parent.type](parent, pushResult) + simpleSwitch(parent.type, parent, pushResult) end end) : case 'setfield' @@ -65,7 +65,7 @@ simpleMap = util.switch() : case 'goto' : call(function (source, pushResult) if source.node then - simpleMap['label'](source.node, pushResult) + simpleSwitch('label', source.node, pushResult) pushResult(source.node) end end) @@ -78,7 +78,6 @@ simpleMap = util.switch() end end end) - : getMap() local function searchField(source, pushResult) local key = guide.getKeyName(source) @@ -132,7 +131,7 @@ local function searchField(source, pushResult) end local searchByParentNode -local nodeMap = util.switch() +local nodeSwitch = util.switch() : case 'field' : case 'method' : call(function (source, pushResult) @@ -162,15 +161,11 @@ local nodeMap = util.switch() : call(function (source, pushResult) searchField(source, pushResult) end) - : getMap() ---@param source parser.object ---@param pushResult fun(src: parser.object) local function searchBySimple(source, pushResult) - local simple = simpleMap[source.type] - if simple then - simple(source, pushResult) - end + simpleSwitch(source.type, source, pushResult) end ---@param source parser.object @@ -190,10 +185,7 @@ end ---@param source parser.object ---@param pushResult fun(src: parser.object) function searchByParentNode(source, pushResult) - local node = nodeMap[source.type] - if node then - node(source, pushResult) - end + nodeSwitch(source.type, source, pushResult) end local function searchByNode(source, pushResult) |