summaryrefslogtreecommitdiff
path: root/script/vm
diff options
context:
space:
mode:
Diffstat (limited to 'script/vm')
-rw-r--r--script/vm/compiler.lua459
-rw-r--r--script/vm/def.lua15
-rw-r--r--script/vm/doc.lua11
-rw-r--r--script/vm/field.lua10
-rw-r--r--script/vm/generic.lua5
-rw-r--r--script/vm/global-manager.lua364
-rw-r--r--script/vm/global.lua431
-rw-r--r--script/vm/infer.lua115
-rw-r--r--script/vm/init.lua10
-rw-r--r--script/vm/library.lua21
-rw-r--r--script/vm/local-id.lua62
-rw-r--r--script/vm/local-manager.lua40
-rw-r--r--script/vm/manager.lua26
-rw-r--r--script/vm/node.lua260
-rw-r--r--script/vm/ref.lua6
-rw-r--r--script/vm/runner.lua444
-rw-r--r--script/vm/sign.lua29
-rw-r--r--script/vm/type.lua11
-rw-r--r--script/vm/value.lua30
-rw-r--r--script/vm/vm.lua1
20 files changed, 1511 insertions, 839 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 8126f393..75620d19 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -1,10 +1,6 @@
local guide = require 'parser.guide'
local util = require 'utility'
-local localID = require 'vm.local-id'
-local globalMgr = require 'vm.global-manager'
-local signMgr = require 'vm.sign'
local config = require 'config'
-local genericMgr = require 'vm.generic'
local rpath = require 'workspace.require-path'
local files = require 'files'
---@class vm
@@ -13,7 +9,6 @@ local vm = require 'vm.vm'
---@class parser.object
---@field _compiledNodes boolean
---@field _node vm.node
----@field _localBase table
---@field _globalBase table
local searchFieldSwitch = util.switch()
@@ -54,7 +49,7 @@ local searchFieldSwitch = util.switch()
: case 'string'
: call(function (suri, source, key, ref, pushResult)
-- change to `string: stringlib` ?
- local stringlib = globalMgr.getGlobal('type', 'stringlib')
+ local stringlib = vm.getGlobal('type', 'stringlib')
if stringlib then
vm.getClassFields(suri, stringlib, key, ref, pushResult)
end
@@ -64,9 +59,9 @@ local searchFieldSwitch = util.switch()
: call(function (suri, node, key, ref, pushResult)
local fields
if key then
- fields = localID.getSources(node, key)
+ fields = vm.getLocalSources(node, key)
else
- fields = localID.getFields(node)
+ fields = vm.getLocalFields(node)
end
if fields then
for _, src in ipairs(fields) do
@@ -119,7 +114,7 @@ local searchFieldSwitch = util.switch()
if type(key) ~= 'string' then
return
end
- local global = globalMgr.getGlobal('variable', node.name, key)
+ local global = vm.getGlobal('variable', node.name, key)
if global then
for _, set in ipairs(global:getSets(suri)) do
pushResult(set)
@@ -131,7 +126,7 @@ local searchFieldSwitch = util.switch()
end
end
else
- local globals = globalMgr.getFields('variable', node.name)
+ local globals = vm.getGlobalFields('variable', node.name)
for _, global in ipairs(globals) do
for _, set in ipairs(global:getSets(suri)) do
pushResult(set)
@@ -158,7 +153,7 @@ local searchFieldSwitch = util.switch()
if type(key) ~= 'string' then
return
end
- local global = globalMgr.getGlobal('variable', node.name, key)
+ local global = vm.getGlobal('variable', node.name, key)
if global then
for _, set in ipairs(global:getSets(suri)) do
pushResult(set)
@@ -168,7 +163,7 @@ local searchFieldSwitch = util.switch()
end
end
else
- local globals = globalMgr.getFields('variable', node.name)
+ local globals = vm.getGlobalFields('variable', node.name)
for _, global in ipairs(globals) do
for _, set in ipairs(global:getSets(suri)) do
pushResult(set)
@@ -185,7 +180,7 @@ local searchFieldSwitch = util.switch()
end)
-function vm.getClassFields(suri, node, key, ref, pushResult)
+function vm.getClassFields(suri, object, key, ref, pushResult)
local mark = {}
local function searchClass(class, searchedFields)
@@ -201,11 +196,51 @@ function vm.getClassFields(suri, node, key, ref, pushResult)
local hasFounded = {}
for _, field in ipairs(set.fields) do
local fieldKey = guide.getKeyName(field)
- if key == nil
- or fieldKey == key then
- if not searchedFields[fieldKey] then
- pushResult(field)
- hasFounded[fieldKey] = true
+ if fieldKey then
+ -- ---@field x boolean -> class.x
+ if key == nil
+ or fieldKey == key then
+ if not searchedFields[fieldKey] then
+ pushResult(field)
+ hasFounded[fieldKey] = true
+ end
+ end
+ end
+ if not hasFounded[fieldKey] then
+ local keyType = type(key)
+ if keyType == 'table' then
+ -- ---@field [integer] boolean -> class[integer]
+ local fieldNode = vm.compileNode(field.field)
+ if vm.isSubType(suri, key.name, fieldNode) then
+ local nkey = '|' .. key.name
+ if not searchedFields[nkey] then
+ pushResult(field)
+ hasFounded[nkey] = true
+ end
+ end
+ else
+ local typeName
+ if keyType == 'number' then
+ if math.tointeger(key) then
+ typeName = 'integer'
+ else
+ typeName = 'number'
+ end
+ elseif keyType == 'boolean'
+ or keyType == 'string' then
+ typeName = keyType
+ end
+ if typeName then
+ -- ---@field [integer] boolean -> class[1]
+ local fieldNode = vm.compileNode(field.field)
+ if vm.isSubType(suri, typeName, fieldNode) then
+ local nkey = '|' .. typeName
+ if not searchedFields[nkey] then
+ pushResult(field)
+ hasFounded[nkey] = true
+ end
+ end
+ end
end
end
end
@@ -214,19 +249,23 @@ function vm.getClassFields(suri, node, key, ref, pushResult)
for _, src in ipairs(set.bindSources) do
searchFieldSwitch(src.type, suri, src, key, ref, function (field)
local fieldKey = guide.getKeyName(field)
- if not searchedFields[fieldKey]
- and guide.isSet(field) then
- hasFounded[fieldKey] = true
- pushResult(field)
+ if fieldKey then
+ if not searchedFields[fieldKey]
+ and guide.isSet(field) then
+ hasFounded[fieldKey] = true
+ pushResult(field)
+ end
end
end)
if src.value and src.value.type == 'table' then
searchFieldSwitch('table', suri, src.value, key, ref, function (field)
local fieldKey = guide.getKeyName(field)
- if not searchedFields[fieldKey]
- and guide.isSet(field) then
- hasFounded[fieldKey] = true
- pushResult(field)
+ if fieldKey then
+ if not searchedFields[fieldKey]
+ and guide.isSet(field) then
+ hasFounded[fieldKey] = true
+ pushResult(field)
+ end
end
end)
end
@@ -239,7 +278,7 @@ function vm.getClassFields(suri, node, key, ref, pushResult)
end
for _, extend in ipairs(set.extends) do
if extend.type == 'doc.extends.name' then
- local extendType = globalMgr.getGlobal('type', extend[1])
+ local extendType = vm.getGlobal('type', extend[1])
if extendType then
searchClass(extendType, searchedFields)
end
@@ -253,12 +292,12 @@ function vm.getClassFields(suri, node, key, ref, pushResult)
local function searchGlobal(class)
if class.cate == 'type' and class.name == '_G' then
if key == nil then
- local sets = globalMgr.getGlobalSets(suri, 'variable')
+ local sets = vm.getGlobalSets(suri, 'variable')
for _, set in ipairs(sets) do
pushResult(set)
end
else
- local global = globalMgr.getGlobal('variable', key)
+ local global = vm.getGlobal('variable', key)
if global then
for _, set in ipairs(global:getSets(suri)) do
pushResult(set)
@@ -268,8 +307,8 @@ function vm.getClassFields(suri, node, key, ref, pushResult)
end
end
- searchClass(node)
- searchGlobal(node)
+ searchClass(object)
+ searchGlobal(object)
end
---@class parser.object
@@ -283,10 +322,13 @@ local function getObjectSign(source)
end
source._sign = false
if source.type == 'function' then
+ if not source.bindDocs then
+ return false
+ end
for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.generic' then
if not source._sign then
- source._sign = signMgr()
+ source._sign = vm.createSign()
break
end
end
@@ -314,14 +356,18 @@ local function getObjectSign(source)
if not hasGeneric then
return false
end
- source._sign = signMgr()
+ source._sign = vm.createSign()
if source.type == 'doc.type.function' then
for _, arg in ipairs(source.args) do
- local argNode = vm.compileNode(arg.extends)
- if arg.optional then
- argNode:addOptional()
+ if arg.extends then
+ local argNode = vm.compileNode(arg.extends)
+ if arg.optional then
+ argNode:addOptional()
+ end
+ source._sign:addSign(argNode)
+ else
+ source._sign:addSign(vm.createNode())
end
- source._sign:addSign(argNode)
end
end
end
@@ -354,7 +400,7 @@ function vm.getReturnOfFunction(func, index)
if not sign then
return rtn
end
- return genericMgr(rtn, sign)
+ return vm.createGeneric(rtn, sign)
end
end
@@ -455,6 +501,9 @@ local function getReturn(func, index, args)
result:merge(rnode)
end
end
+ if result and returnNode:isOptional() then
+ result:addOptional()
+ end
end
end
end
@@ -462,6 +511,25 @@ local function getReturn(func, index, args)
return result
end
+---@param source parser.object
+---@return boolean
+local function bindAs(source)
+ local root = guide.getRoot(source)
+ local docs = root.docs
+ if not docs then
+ return
+ end
+ for _, doc in ipairs(docs) do
+ if doc.type == 'doc.as' and doc.originalComment.start == source.finish + 2 then
+ if doc.as then
+ vm.setNode(source, vm.compileNode(doc.as), true)
+ end
+ return true
+ end
+ end
+ return false
+end
+
local function bindDocs(source)
local isParam = source.parent.type == 'funcargs'
or source.parent.type == 'in'
@@ -485,7 +553,11 @@ local function bindDocs(source)
end
if doc.type == 'doc.param' then
if isParam and source[1] == doc.param[1] then
- vm.setNode(source, vm.compileNode(doc))
+ local node = vm.compileNode(doc)
+ if doc.optional then
+ node:addOptional()
+ end
+ vm.setNode(source, node)
return true
end
end
@@ -503,12 +575,17 @@ local function bindDocs(source)
vm.setNode(source, vm.compileNode(ast))
return true
end
+ if doc.type == 'doc.overload' then
+ if not isParam then
+ vm.setNode(source, vm.compileNode(doc))
+ end
+ end
end
return false
end
local function compileByLocalID(source)
- local sources = localID.getSources(source)
+ local sources = vm.getLocalSources(source)
if not sources then
return
end
@@ -571,7 +648,7 @@ local function selectNode(source, list, index)
if exp.type == 'call' then
result = getReturn(exp.node, index, exp.args)
if not result then
- vm.setNode(source, globalMgr.getGlobal('type', 'unknown'))
+ vm.setNode(source, vm.declareGlobal('type', 'unknown'))
return vm.getNode(source)
end
else
@@ -597,7 +674,7 @@ local function selectNode(source, list, index)
end
end
if not hasKnownType then
- rtnNode:merge(globalMgr.getGlobal('type', 'unknown'))
+ rtnNode:merge(vm.declareGlobal('type', 'unknown'))
end
vm.setNode(source, rtnNode)
return rtnNode
@@ -664,10 +741,21 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
for n in callNode:eachObject() do
if n.type == 'function' then
+ local sign = getObjectSign(n)
local farg = getFuncArg(n, myIndex)
if farg then
for fn in vm.compileNode(farg):eachObject() do
if isValidCallArgNode(arg, fn) then
+ if fn.type == 'doc.type.function' then
+ if sign then
+ local generic = vm.createGeneric(fn, sign)
+ local args = {}
+ for i = fixIndex + 1, myIndex - 1 do
+ args[#args+1] = call.args[i]
+ end
+ fn = generic:resolve(guide.getUri(call), args)
+ end
+ end
vm.setNode(arg, fn)
end
end
@@ -716,29 +804,19 @@ function vm.compileCallArg(arg, call, index)
if call.node.special == 'pcall'
or call.node.special == 'xpcall' then
local fixIndex = call.node.special == 'pcall' and 1 or 2
- callNode = vm.compileNode(call.args[1])
- compileCallArgNode(arg, call, callNode, fixIndex, index - fixIndex)
+ if call.args and call.args[1] then
+ callNode = vm.compileNode(call.args[1])
+ compileCallArgNode(arg, call, callNode, fixIndex, index - fixIndex)
+ end
end
return vm.getNode(arg)
end
---@param source parser.object
---@return vm.node
-local function compileLocalBase(source)
- if not source._localBase then
- source._localBase = {
- type = 'localbase',
- parent = source,
- }
- end
- local baseNode = vm.getNode(source._localBase)
- if baseNode then
- return baseNode
- end
- baseNode = vm.createNode()
- vm.setNode(source._localBase, baseNode, true)
-
+local function compileLocal(source)
vm.setNode(source, source)
+
local hasMarkDoc
if source.bindDocs then
hasMarkDoc = bindDocs(source)
@@ -788,14 +866,19 @@ local function compileLocalBase(source)
if n.type == 'doc.type.function' then
for index, arg in ipairs(n.args) do
if func.args[index] == source then
- vm.setNode(source, vm.compileNode(arg))
+ local argNode = vm.compileNode(arg)
+ for an in argNode:eachObject() do
+ if an.type ~= 'doc.generic.name' then
+ vm.setNode(source, an)
+ end
+ end
hasDocArg = true
end
end
end
end
if not hasDocArg then
- vm.setNode(source, globalMgr.getGlobal('type', 'any'))
+ vm.setNode(source, vm.declareGlobal('type', 'any'))
end
end
-- for x in ... do
@@ -805,15 +888,10 @@ local function compileLocalBase(source)
-- for x = ... do
if source.parent.type == 'loop' then
- vm.setNode(source, globalMgr.getGlobal('type', 'integer'))
+ vm.compileNode(source.parent)
end
- baseNode:merge(vm.getNode(source))
- vm.removeNode(source)
-
- baseNode:setData('hasDefined', hasMarkDoc or hasMarkParam or hasMarkValue)
-
- return baseNode
+ vm.getNode(source):setData('hasDefined', hasMarkDoc or hasMarkParam or hasMarkValue)
end
local compilerSwitch = util.switch()
@@ -867,41 +945,79 @@ local compilerSwitch = util.switch()
end)
: case 'paren'
: call(function (source)
+ if bindAs(source) then
+ return
+ end
if source.exp then
vm.setNode(source, vm.compileNode(source.exp))
end
end)
: case 'local'
: case 'self'
+ ---@param source parser.object
: call(function (source)
- local baseNode = compileLocalBase(source)
- vm.setNode(source, baseNode, true)
- if not baseNode:getData 'hasDefined' and source.ref then
+ compileLocal(source)
+ local refs = source.ref
+ if not refs then
+ return
+ end
+
+ local hasMark = vm.getNode(source):getData 'hasDefined'
+
+ local runner = vm.createRunner(source)
+ runner:launch(function (src, node)
+ if src.type == 'setlocal' then
+ if src.bindDocs then
+ for _, doc in ipairs(src.bindDocs) do
+ if doc.type == 'doc.type' then
+ vm.setNode(src, vm.compileNode(doc), true)
+ return vm.getNode(src)
+ end
+ end
+ end
+ if src.value and guide.isLiteral(src.value) then
+ if src.value.type == 'table' then
+ vm.setNode(src, vm.createNode(src.value), true)
+ else
+ vm.setNode(src, vm.compileNode(src.value), true)
+ end
+ elseif src.value
+ and src.value.type == 'binary'
+ and src.value.op and src.value.op.type == 'or'
+ and src.value[1] and src.value[1].type == 'getlocal' and src.value[1].node == source then
+ -- x = x or 1
+ vm.setNode(src, vm.compileNode(src.value))
+ else
+ vm.setNode(src, node, true)
+ end
+ return vm.getNode(src)
+ elseif src.type == 'getlocal' then
+ if bindAs(src) then
+ return
+ end
+ vm.setNode(src, node, true)
+ end
+ end)
+
+ if not hasMark then
+ local parentFunc = guide.getParentFunction(source)
for _, ref in ipairs(source.ref) do
- if ref.type == 'setlocal' then
- vm.setNode(source, vm.compileNode(ref))
+ if ref.type == 'setlocal'
+ and guide.getParentFunction(ref) == parentFunc then
+ vm.setNode(source, vm.getNode(ref))
end
end
end
end)
: case 'setlocal'
: call(function (source)
- local baseNode = compileLocalBase(source.node)
- if not baseNode:getData 'hasDefined' and source.value then
- if source.value.type == 'table' then
- vm.setNode(source, source.value)
- else
- vm.setNode(source, vm.compileNode(source.value))
- end
- end
- baseNode:merge(vm.getNode(source))
- vm.setNode(source, baseNode, true)
vm.compileNode(source.node)
end)
: case 'getlocal'
: call(function (source)
- local baseNode = compileLocalBase(source.node)
- vm.setNode(source, baseNode, true)
+ if bindAs(source) then
+ return
+ end
vm.compileNode(source.node)
end)
: case 'setfield'
@@ -924,6 +1040,9 @@ local compilerSwitch = util.switch()
: case 'getmethod'
: case 'getindex'
: call(function (source)
+ if bindAs(source) then
+ return
+ end
compileByLocalID(source)
local key = guide.getKeyName(source)
if key == nil and source.index then
@@ -959,6 +1078,9 @@ local compilerSwitch = util.switch()
end)
: case 'getglobal'
: call(function (source)
+ if bindAs(source) then
+ return
+ end
if source.node[1] ~= '_ENV' then
return
end
@@ -1019,7 +1141,7 @@ local compilerSwitch = util.switch()
end)
end
if hasGeneric then
- vm.setNode(source, genericMgr(rtn, sign))
+ vm.setNode(source, vm.createGeneric(rtn, sign))
else
vm.setNode(source, vm.compileNode(rtn))
end
@@ -1092,29 +1214,44 @@ local compilerSwitch = util.switch()
-- for k, v in pairs(t) do
--> for k, v in iterator, status, initValue do
--> local k, v = iterator(status, initValue)
- source._iterator = {}
- source._iterArgs = {{}, {}}
- -- iterator
- selectNode(source._iterator, source.exps, 1)
- -- status
- selectNode(source._iterArgs[1], source.exps, 2)
- -- initValue
- selectNode(source._iterArgs[2], source.exps, 3)
- end
+ source._iterator = {
+ type = 'dummyfunc',
+ parent = source,
+ }
+ source._iterArgs = {{},{}}
+ end
+ -- iterator
+ selectNode(source._iterator, source.exps, 1)
+ -- status
+ selectNode(source._iterArgs[1], source.exps, 2)
+ -- initValue
+ selectNode(source._iterArgs[2], source.exps, 3)
if source.keys then
for i, loc in ipairs(source.keys) do
local node = getReturn(source._iterator, i, source._iterArgs)
if node then
+ if i == 1 then
+ node:removeOptional()
+ end
vm.setNode(loc, node)
end
end
end
end)
+ : case 'loop'
+ : call(function (source)
+ if source.loc then
+ vm.setNode(source.loc, vm.declareGlobal('type', 'integer'))
+ end
+ end)
: case 'doc.type'
: call(function (source)
for _, typeUnit in ipairs(source.types) do
vm.setNode(source, vm.compileNode(typeUnit))
end
+ if source.optional then
+ vm.getNode(source):addOptional()
+ end
end)
: case 'doc.type.integer'
: case 'doc.type.string'
@@ -1130,7 +1267,13 @@ local compilerSwitch = util.switch()
: call(function (source)
local uri = guide.getUri(source)
vm.setNode(source, source)
- local global = globalMgr.getGlobal('type', source.node[1])
+ if not source.node[1] then
+ return
+ end
+ local global = vm.getGlobal('type', source.node[1])
+ if not global then
+ return
+ end
for _, set in ipairs(global:getSets(uri)) do
if set.type == 'doc.class' then
if set.extends then
@@ -1161,14 +1304,22 @@ local compilerSwitch = util.switch()
if not source.extends then
return
end
- vm.setNode(source, vm.compileNode(source.extends))
+ local fieldNode = vm.compileNode(source.extends)
+ if source.optional then
+ fieldNode:addOptional()
+ end
+ vm.setNode(source, fieldNode)
end)
: case 'doc.type.field'
: call(function (source)
if not source.extends then
return
end
- vm.setNode(source, vm.compileNode(source.extends))
+ local fieldNode = vm.compileNode(source.extends)
+ if source.optional then
+ fieldNode:addOptional()
+ end
+ vm.setNode(source, fieldNode)
end)
: case 'doc.param'
: call(function (source)
@@ -1208,7 +1359,7 @@ local compilerSwitch = util.switch()
end)
: case 'doc.see.name'
: call(function (source)
- local type = globalMgr.getGlobal('type', source[1])
+ local type = vm.getGlobal('type', source[1])
if type then
vm.setNode(source, vm.compileNode(type))
end
@@ -1218,7 +1369,10 @@ local compilerSwitch = util.switch()
if source.extends then
vm.setNode(source, vm.compileNode(source.extends))
else
- vm.setNode(source, globalMgr.getGlobal('type', 'any'))
+ vm.setNode(source, vm.declareGlobal('type', 'any'))
+ end
+ if source.optional then
+ vm.getNode(source):addOptional()
end
end)
: case 'generic'
@@ -1227,10 +1381,16 @@ local compilerSwitch = util.switch()
end)
: case 'unary'
: call(function (source)
+ if bindAs(source) then
+ return
+ end
+ if not source[1] then
+ return
+ end
if source.op.type == 'not' then
local result = vm.test(source[1])
if result == nil then
- vm.setNode(source, globalMgr.getGlobal('type', 'boolean'))
+ vm.setNode(source, vm.declareGlobal('type', 'boolean'))
return
else
vm.setNode(source, {
@@ -1244,13 +1404,13 @@ local compilerSwitch = util.switch()
end
end
if source.op.type == '#' then
- vm.setNode(source, globalMgr.getGlobal('type', 'integer'))
+ vm.setNode(source, vm.declareGlobal('type', 'integer'))
return
end
if source.op.type == '-' then
local v = vm.getNumber(source[1])
if v == nil then
- vm.setNode(source, globalMgr.getGlobal('type', 'number'))
+ vm.setNode(source, vm.declareGlobal('type', 'number'))
return
else
vm.setNode(source, {
@@ -1266,7 +1426,7 @@ local compilerSwitch = util.switch()
if source.op.type == '~' then
local v = vm.getInteger(source[1])
if v == nil then
- vm.setNode(source, globalMgr.getGlobal('type', 'integer'))
+ vm.setNode(source, vm.declareGlobal('type', 'integer'))
return
else
vm.setNode(source, {
@@ -1282,34 +1442,42 @@ local compilerSwitch = util.switch()
end)
: case 'binary'
: call(function (source)
+ if bindAs(source) then
+ return
+ end
+ if not source[1] or not source[2] then
+ return
+ end
if source.op.type == 'and' then
+ local node1 = vm.compileNode(source[1])
+ local node2 = vm.compileNode(source[2])
local r1 = vm.test(source[1])
if r1 == true then
- vm.setNode(source, vm.compileNode(source[2]))
- return
- end
- if r1 == false then
- vm.setNode(source, vm.compileNode(source[1]))
- return
+ vm.setNode(source, node2)
+ elseif r1 == false then
+ vm.setNode(source, node1)
+ else
+ vm.setNode(source, node2)
end
- return
end
if source.op.type == 'or' then
+ local node1 = vm.compileNode(source[1])
+ local node2 = vm.compileNode(source[2])
local r1 = vm.test(source[1])
if r1 == true then
- vm.setNode(source, vm.compileNode(source[1]))
- return
- end
- if r1 == false then
- vm.setNode(source, vm.compileNode(source[2]))
- return
+ vm.setNode(source, node1)
+ elseif r1 == false then
+ vm.setNode(source, node2)
+ else
+ vm.getNode(source):merge(node1)
+ vm.getNode(source):setTruthy()
+ vm.getNode(source):merge(node2)
end
- return
end
if source.op.type == '==' then
local result = vm.equal(source[1], source[2])
if result == nil then
- vm.setNode(source, globalMgr.getGlobal('type', 'boolean'))
+ vm.setNode(source, vm.declareGlobal('type', 'boolean'))
return
else
vm.setNode(source, {
@@ -1325,7 +1493,7 @@ local compilerSwitch = util.switch()
if source.op.type == '~=' then
local result = vm.equal(source[1], source[2])
if result == nil then
- vm.setNode(source, globalMgr.getGlobal('type', 'boolean'))
+ vm.setNode(source, vm.declareGlobal('type', 'boolean'))
return
else
vm.setNode(source, {
@@ -1351,7 +1519,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'integer'))
+ vm.setNode(source, vm.declareGlobal('type', 'integer'))
return
end
end
@@ -1368,7 +1536,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'integer'))
+ vm.setNode(source, vm.declareGlobal('type', 'integer'))
return
end
end
@@ -1385,7 +1553,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'integer'))
+ vm.setNode(source, vm.declareGlobal('type', 'integer'))
return
end
end
@@ -1402,7 +1570,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'integer'))
+ vm.setNode(source, vm.declareGlobal('type', 'integer'))
return
end
end
@@ -1419,7 +1587,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'integer'))
+ vm.setNode(source, vm.declareGlobal('type', 'integer'))
return
end
end
@@ -1437,7 +1605,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'number'))
+ vm.setNode(source, vm.declareGlobal('type', 'number'))
return
end
end
@@ -1455,7 +1623,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'number'))
+ vm.setNode(source, vm.declareGlobal('type', 'number'))
return
end
end
@@ -1473,7 +1641,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'number'))
+ vm.setNode(source, vm.declareGlobal('type', 'number'))
return
end
end
@@ -1490,14 +1658,14 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'number'))
+ vm.setNode(source, vm.declareGlobal('type', 'number'))
return
end
end
if source.op.type == '%' then
local a = vm.getNumber(source[1])
local b = vm.getNumber(source[2])
- if a and b then
+ if a and b and b ~= 0 then
local result = a % b
vm.setNode(source, {
type = math.type(result) == 'integer' and 'integer' or 'number',
@@ -1508,7 +1676,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'number'))
+ vm.setNode(source, vm.declareGlobal('type', 'number'))
return
end
end
@@ -1525,7 +1693,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'number'))
+ vm.setNode(source, vm.declareGlobal('type', 'number'))
return
end
end
@@ -1543,7 +1711,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'number'))
+ vm.setNode(source, vm.declareGlobal('type', 'number'))
return
end
end
@@ -1580,7 +1748,7 @@ local compilerSwitch = util.switch()
})
return
else
- vm.setNode(source, globalMgr.getGlobal('type', 'string'))
+ vm.setNode(source, vm.declareGlobal('type', 'string'))
return
end
end
@@ -1614,17 +1782,20 @@ local function compileByGlobal(source)
vm.setNode(source, globalNode, true)
return
end
+ ---@type vm.node
globalNode = vm.createNode(global)
vm.setNode(root._globalBase[name], globalNode, true)
+ vm.setNode(source, globalNode, true)
- local sets = global.links[uri].sets or {}
- local gets = global.links[uri].gets or {}
- for _, set in ipairs(sets) do
- vm.setNode(set, globalNode, true)
- end
- for _, get in ipairs(gets) do
- vm.setNode(get, globalNode, true)
- end
+ -- TODO:don't mix
+ --local sets = global.links[uri].sets or {}
+ --local gets = global.links[uri].gets or {}
+ --for _, set in ipairs(sets) do
+ -- vm.setNode(set, globalNode, true)
+ --end
+ --for _, get in ipairs(gets) do
+ -- vm.setNode(get, globalNode, true)
+ --end
if global.cate == 'variable' then
local hasMarkDoc
@@ -1672,7 +1843,11 @@ end
---@return vm.node
function vm.compileNode(source)
if not source then
- error('Can not compile nil node')
+ if TEST then
+ error('Can not compile nil source')
+ else
+ log.error('Can not compile nil source')
+ end
end
if source.type == 'global' then
diff --git a/script/vm/def.lua b/script/vm/def.lua
index b66e8fda..83e92686 100644
--- a/script/vm/def.lua
+++ b/script/vm/def.lua
@@ -2,8 +2,6 @@
local vm = require 'vm.vm'
local util = require 'utility'
local guide = require 'parser.guide'
-local localID = require 'vm.local-id'
-local globalMgr = require 'vm.global-manager'
local simpleSwitch
@@ -79,6 +77,13 @@ simpleSwitch = util.switch()
pushResult(source.node)
end
end)
+ : case 'doc.cast.name'
+ : call(function (source, pushResult)
+ local loc = guide.getLocal(source, source[1], source.start)
+ if loc then
+ pushResult(loc)
+ end
+ end)
local searchFieldSwitch = util.switch()
: case 'table'
@@ -97,7 +102,7 @@ local searchFieldSwitch = util.switch()
---@param key string
: call(function (suri, obj, key, pushResult)
if obj.cate == 'variable' then
- local newGlobal = globalMgr.getGlobal('variable', obj.name, key)
+ local newGlobal = vm.getGlobal('variable', obj.name, key)
if newGlobal then
for _, set in ipairs(newGlobal:getSets(suri)) do
pushResult(set)
@@ -110,7 +115,7 @@ local searchFieldSwitch = util.switch()
end)
: case 'local'
: call(function (suri, obj, key, pushResult)
- local sources = localID.getSources(obj, key)
+ local sources = vm.getLocalSources(obj, key)
if sources then
for _, src in ipairs(sources) do
if guide.isSet(src) then
@@ -189,7 +194,7 @@ end
---@param source parser.object
---@param pushResult fun(src: parser.object)
local function searchByLocalID(source, pushResult)
- local idSources = localID.getSources(source)
+ local idSources = vm.getLocalSources(source)
if not idSources then
return
end
diff --git a/script/vm/doc.lua b/script/vm/doc.lua
index 5a92a103..e2b383b6 100644
--- a/script/vm/doc.lua
+++ b/script/vm/doc.lua
@@ -3,7 +3,6 @@ local guide = require 'parser.guide'
---@class vm
local vm = require 'vm.vm'
local config = require 'config'
-local globalMgr = require 'vm.global-manager'
---获取class与alias
---@param suri uri
@@ -11,13 +10,13 @@ local globalMgr = require 'vm.global-manager'
---@return parser.object[]
function vm.getDocSets(suri, name)
if name then
- local global = globalMgr.getGlobal('type', name)
+ local global = vm.getGlobal('type', name)
if not global then
return {}
end
return global:getSets(suri)
else
- return globalMgr.getGlobalSets(suri, 'type')
+ return vm.getGlobalSets(suri, 'type')
end
end
@@ -27,6 +26,9 @@ function vm.isMetaFile(uri)
return false
end
local cache = files.getCache(uri)
+ if not cache then
+ return false
+ end
if cache.isMeta ~= nil then
return cache.isMeta
end
@@ -332,6 +334,9 @@ function vm.isDiagDisabledAt(uri, position, name)
return false
end
local cache = files.getCache(uri)
+ if not cache then
+ return false
+ end
if not cache.diagnosticRanges then
cache.diagnosticRanges = {}
for _, doc in ipairs(status.ast.docs) do
diff --git a/script/vm/field.lua b/script/vm/field.lua
index ba7cd4c1..5de838be 100644
--- a/script/vm/field.lua
+++ b/script/vm/field.lua
@@ -15,6 +15,15 @@ local searchByNodeSwitch = util.switch()
pushResult(source)
end)
+local function searchByLocalID(source, pushResult)
+ local fields = vm.getLocalFields(source)
+ if fields then
+ for _, field in ipairs(fields) do
+ pushResult(field)
+ end
+ end
+end
+
local function searchByNode(source, pushResult)
local uri = guide.getUri(source)
vm.compileByParentNode(source, nil, true, function (field)
@@ -35,6 +44,7 @@ function vm.getFields(source)
end
end
+ searchByLocalID(source, pushResult)
searchByNode(source, pushResult)
return results
diff --git a/script/vm/generic.lua b/script/vm/generic.lua
index b3981ff8..6462028e 100644
--- a/script/vm/generic.lua
+++ b/script/vm/generic.lua
@@ -1,3 +1,4 @@
+---@class vm
local vm = require 'vm.vm'
---@class parser.object
@@ -114,7 +115,7 @@ end
---@param uri uri
---@param args parser.object
----@return parser.object
+---@return vm.node
function mt:resolve(uri, args)
local resolved = self.sign:resolve(uri, args)
local protoNode = vm.compileNode(self.proto)
@@ -129,7 +130,7 @@ end
---@param proto vm.object
---@param sign vm.sign
---@return vm.generic
-return function (proto, sign)
+function vm.createGeneric(proto, sign)
local generic = setmetatable({
sign = sign,
proto = proto,
diff --git a/script/vm/global-manager.lua b/script/vm/global-manager.lua
deleted file mode 100644
index f25bb5a0..00000000
--- a/script/vm/global-manager.lua
+++ /dev/null
@@ -1,364 +0,0 @@
-local util = require 'utility'
-local guide = require 'parser.guide'
-local globalBuilder = require 'vm.global'
-local signMgr = require 'vm.sign'
-local genericMgr = require 'vm.generic'
----@class vm
-local vm = require 'vm.vm'
-
----@class parser.object
----@field _globalNode vm.global
-
----@class vm.global-manager
-local m = {}
----@type table<string, vm.global>
-m.globals = {}
----@type table<uri, table<string, boolean>>
-m.globalSubs = util.multiTable(2)
-
-local compilerGlobalSwitch = util.switch()
- : case 'local'
- : call(function (source)
- if source.special ~= '_G' then
- return
- end
- if source.ref then
- for _, ref in ipairs(source.ref) do
- m.compileObject(ref)
- end
- end
- end)
- : case 'getlocal'
- : call(function (source)
- if source.special ~= '_G' then
- return
- end
- if not source.next then
- return
- end
- m.compileObject(source.next)
- end)
- : case 'setglobal'
- : call(function (source)
- local uri = guide.getUri(source)
- local name = guide.getKeyName(source)
- local global = m.declareGlobal('variable', name, uri)
- global:addSet(uri, source)
- source._globalNode = global
- end)
- : case 'getglobal'
- : call(function (source)
- local uri = guide.getUri(source)
- local name = guide.getKeyName(source)
- local global = m.declareGlobal('variable', name, uri)
- global:addGet(uri, source)
- source._globalNode = global
-
- local nxt = source.next
- if nxt then
- m.compileObject(nxt)
- end
- end)
- : case 'setfield'
- : case 'setmethod'
- : case 'setindex'
- ---@param source parser.object
- : call(function (source)
- local name
- local keyName = guide.getKeyName(source)
- if not keyName then
- return
- end
- if source.node._globalNode then
- local parentName = source.node._globalNode:getName()
- if parentName == '_G' then
- name = keyName
- else
- name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName)
- end
- elseif source.node.special == '_G' then
- name = keyName
- end
- if not name then
- return
- end
- local uri = guide.getUri(source)
- local global = m.declareGlobal('variable', name, uri)
- global:addSet(uri, source)
- source._globalNode = global
- end)
- : case 'getfield'
- : case 'getmethod'
- : case 'getindex'
- ---@param source parser.object
- : call(function (source)
- local name
- local keyName = guide.getKeyName(source)
- if not keyName then
- return
- end
- if source.node._globalNode then
- local parentName = source.node._globalNode:getName()
- if parentName == '_G' then
- name = keyName
- else
- name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName)
- end
- elseif source.node.special == '_G' then
- name = keyName
- end
- local uri = guide.getUri(source)
- local global = m.declareGlobal('variable', name, uri)
- global:addGet(uri, source)
- source._globalNode = global
-
- local nxt = source.next
- if nxt then
- m.compileObject(nxt)
- end
- end)
- : case 'call'
- : call(function (source)
- if source.node.special == 'rawset'
- or source.node.special == 'rawget' then
- if not source.args then
- return
- end
- local g = source.args[1]
- local key = source.args[2]
- if g and key and g.special == '_G' then
- local name = guide.getKeyName(key)
- if name then
- local uri = guide.getUri(source)
- local global = m.declareGlobal('variable', name, uri)
- if source.node.special == 'rawset' then
- global:addSet(uri, source)
- source.value = source.args[3]
- else
- global:addGet(uri, source)
- end
- source._globalNode = global
-
- local nxt = source.next
- if nxt then
- m.compileObject(nxt)
- end
- end
- end
- end
- end)
- : case 'doc.class'
- ---@param source parser.object
- : call(function (source)
- local uri = guide.getUri(source)
- local name = guide.getKeyName(source)
- local class = m.declareGlobal('type', name, uri)
- class:addSet(uri, source)
- source._globalNode = class
-
- if source.signs then
- source._sign = signMgr()
- for _, sign in ipairs(source.signs) do
- source._sign:addSign(vm.compileNode(sign))
- end
- if source.extends then
- for _, ext in ipairs(source.extends) do
- if ext.type == 'doc.type.table' then
- ext._generic = genericMgr(ext, source._sign)
- end
- end
- end
- end
- end)
- : case 'doc.alias'
- : call(function (source)
- local uri = guide.getUri(source)
- local name = guide.getKeyName(source)
- local alias = m.declareGlobal('type', name, uri)
- alias:addSet(uri, source)
- source._globalNode = alias
-
- if source.signs then
- source._sign = signMgr()
- for _, sign in ipairs(source.signs) do
- source._sign:addSign(vm.compileNode(sign))
- end
- source.extends._generic = genericMgr(source.extends, source._sign)
- end
- end)
- : case 'doc.type.name'
- : call(function (source)
- local uri = guide.getUri(source)
- local name = source[1]
- local type = m.declareGlobal('type', name, uri)
- type:addGet(uri, source)
- source._globalNode = type
- end)
- : case 'doc.extends.name'
- : call(function (source)
- local uri = guide.getUri(source)
- local name = source[1]
- local class = m.declareGlobal('type', name, uri)
- class:addGet(uri, source)
- source._globalNode = class
- end)
-
-
----@alias vm.global.cate '"variable"' | '"type"'
-
----@param cate vm.global.cate
----@param name string
----@param uri uri
----@return vm.global
-function m.declareGlobal(cate, name, uri)
- local key = cate .. '|' .. name
- m.globalSubs[uri][key] = true
- if not m.globals[key] then
- m.globals[key] = globalBuilder(name, cate)
- end
- return m.globals[key]
-end
-
----@param cate vm.global.cate
----@param name string
----@param field? string
----@return vm.global?
-function m.getGlobal(cate, name, field)
- local key = cate .. '|' .. name
- if field then
- key = key .. vm.ID_SPLITE .. field
- end
- return m.globals[key]
-end
-
----@param cate vm.global.cate
----@param name string
----@return vm.global[]
-function m.getFields(cate, name)
- local globals = {}
- local key = cate .. '|' .. name
-
- -- TODO: optimize
- local clock = os.clock()
- for gid, global in pairs(m.globals) do
- if gid ~= key
- and util.stringStartWith(gid, key)
- and gid:sub(#key + 1, #key + 1) == vm.ID_SPLITE
- and not gid:find(vm.ID_SPLITE, #key + 2) then
- globals[#globals+1] = global
- end
- end
- local cost = os.clock() - clock
- if cost > 0.1 then
- log.warn('global-manager getFields cost %.3f', cost)
- end
-
- return globals
-end
-
----@param cate vm.global.cate
----@return vm.global[]
-function m.getGlobals(cate)
- local globals = {}
-
- -- TODO: optimize
- local clock = os.clock()
- for gid, global in pairs(m.globals) do
- if util.stringStartWith(gid, cate)
- and not gid:find(vm.ID_SPLITE) then
- globals[#globals+1] = global
- end
- end
- local cost = os.clock() - clock
- if cost > 0.1 then
- log.warn('global-manager getGlobals cost %.3f', cost)
- end
-
- return globals
-end
-
----@param suri uri
----@param cate vm.global.cate
----@return parser.object[]
-function m.getGlobalSets(suri, cate)
- local globals = m.getGlobals(cate)
- local result = {}
- for _, global in ipairs(globals) do
- local sets = global:getSets(suri)
- for _, set in ipairs(sets) do
- result[#result+1] = set
- end
- end
- return result
-end
-
----@param suri uri
----@param cate vm.global.cate
----@param name string
----@return boolean
-function m.hasGlobalSets(suri, cate, name)
- local global = m.getGlobal(cate, name)
- if not global then
- return false
- end
- local sets = global:getSets(suri)
- if #sets == 0 then
- return false
- end
- return true
-end
-
----@param source parser.object
-function m.compileObject(source)
- if source._globalNode ~= nil then
- return
- end
- source._globalNode = false
- compilerGlobalSwitch(source.type, source)
-end
-
----@param source parser.object
-function m.compileAst(source)
- local env = guide.getENV(source)
- m.compileObject(env)
- guide.eachSpecialOf(source, 'rawset', function (src)
- m.compileObject(src.parent)
- end)
- guide.eachSpecialOf(source, 'rawget', function (src)
- m.compileObject(src.parent)
- end)
- guide.eachSourceTypes(source.docs, {
- 'doc.class',
- 'doc.alias',
- 'doc.type.name',
- 'doc.extends.name',
- }, function (src)
- m.compileObject(src)
- end)
-end
-
----@return vm.global
-function m.getNode(source)
- if source.type == 'field'
- or source.type == 'method' then
- source = source.parent
- end
- return source._globalNode
-end
-
----@param uri uri
-function m.dropUri(uri)
- local globalSub = m.globalSubs[uri]
- m.globalSubs[uri] = nil
- for key in pairs(globalSub) do
- local global = m.globals[key]
- if global then
- global:dropUri(uri)
- if not global:isAlive() then
- m.globals[key] = nil
- end
- end
- end
-end
-
-return m
diff --git a/script/vm/global.lua b/script/vm/global.lua
index 1c46c9a3..a54ab552 100644
--- a/script/vm/global.lua
+++ b/script/vm/global.lua
@@ -1,5 +1,9 @@
-local util = require 'utility'
-local scope= require 'workspace.scope'
+local util = require 'utility'
+local scope = require 'workspace.scope'
+local guide = require 'parser.guide'
+local files = require 'files'
+---@class vm
+local vm = require 'vm.vm'
---@class vm.global.link
---@field gets parser.object[]
@@ -15,8 +19,6 @@ mt.__index = mt
mt.type = 'global'
mt.name = ''
-local ID_SPLITE = '\x1F'
-
---@param uri uri
---@param source parser.object
function mt:addSet(uri, source)
@@ -106,7 +108,7 @@ end
---@return string
function mt:getKeyName()
- return self.name:match('[^' .. ID_SPLITE .. ']+$')
+ return self.name:match('[^' .. vm.ID_SPLITE .. ']+$')
end
---@return boolean
@@ -116,10 +118,427 @@ end
---@param cate vm.global.cate
---@return vm.global
-return function (name, cate)
+local function createGlobal(name, cate)
return setmetatable({
name = name,
cate = cate,
links = util.multiTable(2),
}, mt)
end
+
+---@class parser.object
+---@field _globalNode vm.global
+
+---@type table<string, vm.global>
+local allGlobals = {}
+---@type table<uri, table<string, boolean>>
+local globalSubs = util.multiTable(2)
+
+local compileObject
+local compilerGlobalSwitch = util.switch()
+ : case 'local'
+ : call(function (source)
+ if source.special ~= '_G' then
+ return
+ end
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ compileObject(ref)
+ end
+ end
+ end)
+ : case 'getlocal'
+ : call(function (source)
+ if source.special ~= '_G' then
+ return
+ end
+ if not source.next then
+ return
+ end
+ compileObject(source.next)
+ end)
+ : case 'setglobal'
+ : call(function (source)
+ local uri = guide.getUri(source)
+ local name = guide.getKeyName(source)
+ local global = vm.declareGlobal('variable', name, uri)
+ global:addSet(uri, source)
+ source._globalNode = global
+ end)
+ : case 'getglobal'
+ : call(function (source)
+ local uri = guide.getUri(source)
+ local name = guide.getKeyName(source)
+ local global = vm.declareGlobal('variable', name, uri)
+ global:addGet(uri, source)
+ source._globalNode = global
+
+ local nxt = source.next
+ if nxt then
+ compileObject(nxt)
+ end
+ end)
+ : case 'setfield'
+ : case 'setmethod'
+ : case 'setindex'
+ ---@param source parser.object
+ : call(function (source)
+ local name
+ local keyName = guide.getKeyName(source)
+ if not keyName then
+ return
+ end
+ if source.node._globalNode then
+ local parentName = source.node._globalNode:getName()
+ if parentName == '_G' then
+ name = keyName
+ else
+ name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName)
+ end
+ elseif source.node.special == '_G' then
+ name = keyName
+ end
+ if not name then
+ return
+ end
+ local uri = guide.getUri(source)
+ local global = vm.declareGlobal('variable', name, uri)
+ global:addSet(uri, source)
+ source._globalNode = global
+ end)
+ : case 'getfield'
+ : case 'getmethod'
+ : case 'getindex'
+ ---@param source parser.object
+ : call(function (source)
+ local name
+ local keyName = guide.getKeyName(source)
+ if not keyName then
+ return
+ end
+ if source.node._globalNode then
+ local parentName = source.node._globalNode:getName()
+ if parentName == '_G' then
+ name = keyName
+ else
+ name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName)
+ end
+ elseif source.node.special == '_G' then
+ name = keyName
+ end
+ local uri = guide.getUri(source)
+ local global = vm.declareGlobal('variable', name, uri)
+ global:addGet(uri, source)
+ source._globalNode = global
+
+ local nxt = source.next
+ if nxt then
+ compileObject(nxt)
+ end
+ end)
+ : case 'call'
+ : call(function (source)
+ if source.node.special == 'rawset'
+ or source.node.special == 'rawget' then
+ if not source.args then
+ return
+ end
+ local g = source.args[1]
+ local key = source.args[2]
+ if g and key and g.special == '_G' then
+ local name = guide.getKeyName(key)
+ if name then
+ local uri = guide.getUri(source)
+ local global = vm.declareGlobal('variable', name, uri)
+ if source.node.special == 'rawset' then
+ global:addSet(uri, source)
+ source.value = source.args[3]
+ else
+ global:addGet(uri, source)
+ end
+ source._globalNode = global
+
+ local nxt = source.next
+ if nxt then
+ compileObject(nxt)
+ end
+ end
+ end
+ end
+ end)
+ : case 'doc.class'
+ ---@param source parser.object
+ : call(function (source)
+ local uri = guide.getUri(source)
+ local name = guide.getKeyName(source)
+ local class = vm.declareGlobal('type', name, uri)
+ class:addSet(uri, source)
+ source._globalNode = class
+
+ if source.signs then
+ source._sign = vm.createSign()
+ for _, sign in ipairs(source.signs) do
+ source._sign:addSign(vm.compileNode(sign))
+ end
+ if source.extends then
+ for _, ext in ipairs(source.extends) do
+ if ext.type == 'doc.type.table' then
+ ext._generic = vm.createGeneric(ext, source._sign)
+ end
+ end
+ end
+ end
+ end)
+ : case 'doc.alias'
+ : call(function (source)
+ local uri = guide.getUri(source)
+ local name = guide.getKeyName(source)
+ local alias = vm.declareGlobal('type', name, uri)
+ alias:addSet(uri, source)
+ source._globalNode = alias
+
+ if source.signs then
+ source._sign = vm.createSign()
+ for _, sign in ipairs(source.signs) do
+ source._sign:addSign(vm.compileNode(sign))
+ end
+ source.extends._generic = vm.createGeneric(source.extends, source._sign)
+ end
+ end)
+ : case 'doc.type.name'
+ : call(function (source)
+ local uri = guide.getUri(source)
+ local name = source[1]
+ local type = vm.declareGlobal('type', name, uri)
+ type:addGet(uri, source)
+ source._globalNode = type
+ end)
+ : case 'doc.extends.name'
+ : call(function (source)
+ local uri = guide.getUri(source)
+ local name = source[1]
+ local class = vm.declareGlobal('type', name, uri)
+ class:addGet(uri, source)
+ source._globalNode = class
+ end)
+
+
+---@alias vm.global.cate '"variable"' | '"type"'
+
+---@param cate vm.global.cate
+---@param name string
+---@param uri? uri
+---@return vm.global
+function vm.declareGlobal(cate, name, uri)
+ local key = cate .. '|' .. name
+ if uri then
+ globalSubs[uri][key] = true
+ end
+ if not allGlobals[key] then
+ allGlobals[key] = createGlobal(name, cate)
+ end
+ return allGlobals[key]
+end
+
+---@param cate vm.global.cate
+---@param name string
+---@param field? string
+---@return vm.global?
+function vm.getGlobal(cate, name, field)
+ local key = cate .. '|' .. name
+ if field then
+ key = key .. vm.ID_SPLITE .. field
+ end
+ return allGlobals[key]
+end
+
+---@param cate vm.global.cate
+---@param name string
+---@return vm.global[]
+function vm.getGlobalFields(cate, name)
+ local globals = {}
+ local key = cate .. '|' .. name
+
+ local clock = os.clock()
+ for gid, global in pairs(allGlobals) do
+ if gid ~= key
+ and util.stringStartWith(gid, key)
+ and gid:sub(#key + 1, #key + 1) == vm.ID_SPLITE
+ and not gid:find(vm.ID_SPLITE, #key + 2) then
+ globals[#globals+1] = global
+ end
+ end
+ local cost = os.clock() - clock
+ if cost > 0.1 then
+ log.warn('global-manager getFields cost %.3f', cost)
+ end
+
+ return globals
+end
+
+---@param cate vm.global.cate
+---@return vm.global[]
+function vm.getGlobals(cate)
+ local globals = {}
+
+ local clock = os.clock()
+ for gid, global in pairs(allGlobals) do
+ if util.stringStartWith(gid, cate)
+ and not gid:find(vm.ID_SPLITE) then
+ globals[#globals+1] = global
+ end
+ end
+ local cost = os.clock() - clock
+ if cost > 0.1 then
+ log.warn('global-manager getGlobals cost %.3f', cost)
+ end
+
+ return globals
+end
+
+---@param suri uri
+---@param cate vm.global.cate
+---@return parser.object[]
+function vm.getGlobalSets(suri, cate)
+ local globals = vm.getGlobals(cate)
+ local result = {}
+ for _, global in ipairs(globals) do
+ local sets = global:getSets(suri)
+ for _, set in ipairs(sets) do
+ result[#result+1] = set
+ end
+ end
+ return result
+end
+
+---@param suri uri
+---@param cate vm.global.cate
+---@param name string
+---@return boolean
+function vm.hasGlobalSets(suri, cate, name)
+ local global = vm.getGlobal(cate, name)
+ if not global then
+ return false
+ end
+ local sets = global:getSets(suri)
+ if #sets == 0 then
+ return false
+ end
+ return true
+end
+
+---@param source parser.object
+function compileObject(source)
+ if source._globalNode ~= nil then
+ return
+ end
+ source._globalNode = false
+ compilerGlobalSwitch(source.type, source)
+end
+
+---@param source parser.object
+local function compileSelf(source)
+ if source.parent.type ~= 'funcargs' then
+ return
+ end
+ ---@type parser.object
+ local node = source.parent.parent and source.parent.parent.parent and source.parent.parent.parent.node
+ if not node then
+ return
+ end
+ local fields = vm.getLocalFields(source)
+ if not fields then
+ return
+ end
+ local nodeLocalID = vm.getLocalID(node)
+ local globalNode = node._globalNode
+ if not nodeLocalID and not globalNode then
+ return
+ end
+ for _, field in ipairs(fields) do
+ if field.type == 'setfield' then
+ local key = guide.getKeyName(field)
+ if key then
+ if nodeLocalID then
+ local myID = nodeLocalID .. vm.ID_SPLITE .. key
+ vm.insertLocalID(myID, field)
+ end
+ if globalNode then
+ local myID = globalNode:getName() .. vm.ID_SPLITE .. key
+ local myGlobal = vm.declareGlobal('variable', myID, guide.getUri(node))
+ myGlobal:addSet(guide.getUri(node), field)
+ end
+ end
+ end
+ end
+end
+
+---@param source parser.object
+local function compileAst(source)
+ local env = guide.getENV(source)
+ if not env then
+ return
+ end
+ compileObject(env)
+ guide.eachSpecialOf(source, 'rawset', function (src)
+ compileObject(src.parent)
+ end)
+ guide.eachSpecialOf(source, 'rawget', function (src)
+ compileObject(src.parent)
+ end)
+ guide.eachSourceTypes(source.docs, {
+ 'doc.class',
+ 'doc.alias',
+ 'doc.type.name',
+ 'doc.extends.name',
+ }, function (src)
+ compileObject(src)
+ end)
+
+ --[[
+ local mt
+ function mt:xxx()
+ self.a = 1
+ end
+
+ mt.a --> find this definition
+ ]]
+ guide.eachSourceType(source, 'self', function (src)
+ compileSelf(src)
+ end)
+end
+
+---@param uri uri
+local function dropUri(uri)
+ local globalSub = globalSubs[uri]
+ globalSubs[uri] = nil
+ for key in pairs(globalSub) do
+ local global = allGlobals[key]
+ if global then
+ global:dropUri(uri)
+ if not global:isAlive() then
+ allGlobals[key] = nil
+ end
+ end
+ end
+end
+
+for uri in files.eachFile() do
+ local state = files.getState(uri)
+ if state then
+ compileAst(state.ast)
+ end
+end
+
+files.watch(function (ev, uri)
+ if ev == 'update' then
+ dropUri(uri)
+ local state = files.getState(uri)
+ if state then
+ compileAst(state.ast)
+ end
+ end
+ if ev == 'remove' then
+ dropUri(uri)
+ end
+end)
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index 2a64ed52..fabc9828 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -1,11 +1,9 @@
local util = require 'utility'
local config = require 'config'
local guide = require 'parser.guide'
+---@class vm
local vm = require 'vm.vm'
----@class vm.infer-manager
-local m = {}
-
---@class vm.infer
---@field views table<string, boolean>
---@field cachedView? string
@@ -21,7 +19,7 @@ mt._hasDocFunction = false
mt._isParam = false
mt._isLocal = false
-m.NULL = setmetatable({}, mt)
+vm.NULL = setmetatable({}, mt)
local inferSorted = {
['boolean'] = - 100,
@@ -52,7 +50,7 @@ local viewNodeSwitch = util.switch()
: call(function (source, infer)
if source.type == 'table' then
if #source == 1 and source[1].type == 'varargs' then
- local node = m.getInfer(source[1]):view()
+ local node = vm.getInfer(source[1]):view()
return ('%s[]'):format(node)
end
end
@@ -90,7 +88,7 @@ local viewNodeSwitch = util.switch()
if source.signs then
local buf = {}
for i, sign in ipairs(source.signs) do
- buf[i] = m.getInfer(sign):view()
+ buf[i] = vm.getInfer(sign):view()
end
return ('%s<%s>'):format(source[1], table.concat(buf, ', '))
else
@@ -99,7 +97,7 @@ local viewNodeSwitch = util.switch()
end)
: case 'generic'
: call(function (source, infer)
- return m.getInfer(source.proto):view()
+ return vm.getInfer(source.proto):view()
end)
: case 'doc.generic.name'
: call(function (source, infer)
@@ -108,7 +106,7 @@ local viewNodeSwitch = util.switch()
: case 'doc.type.array'
: call(function (source, infer)
infer._hasClass = true
- local view = m.getInfer(source.node):view()
+ local view = vm.getInfer(source.node):view()
if source.node.type == 'doc.type' then
view = '(' .. view .. ')'
end
@@ -119,7 +117,7 @@ local viewNodeSwitch = util.switch()
infer._hasClass = true
local buf = {}
for i, sign in ipairs(source.signs) do
- buf[i] = m.getInfer(sign):view()
+ buf[i] = vm.getInfer(sign):view()
end
return ('%s<%s>'):format(source.node[1], table.concat(buf, ', '))
end)
@@ -144,20 +142,23 @@ local viewNodeSwitch = util.switch()
local argView = ''
local regView = ''
for i, arg in ipairs(source.args) do
+ local argNode = vm.compileNode(arg)
+ local isOptional = argNode:isOptional()
+ if isOptional then
+ argNode = argNode:copy()
+ argNode:removeOptional()
+ end
args[i] = string.format('%s%s: %s'
, arg.name[1]
- , arg.optional and '?' or ''
- , m.getInfer(arg):view()
+ , isOptional and '?' or ''
+ , vm.getInfer(argNode):view()
)
end
if #args > 0 then
argView = table.concat(args, ', ')
end
for i, ret in ipairs(source.returns) do
- rets[i] = string.format('%s%s'
- , m.getInfer(ret):view()
- , ret.optional and '?' or ''
- )
+ rets[i] = vm.getInfer(ret):view()
end
if #rets > 0 then
regView = ':' .. table.concat(rets, ', ')
@@ -165,16 +166,21 @@ local viewNodeSwitch = util.switch()
return ('fun(%s)%s'):format(argView, regView)
end)
----@param source parser.object
+---@param source parser.object | vm.node
---@return vm.infer
-function m.getInfer(source)
- local node = vm.compileNode(source)
+function vm.getInfer(source)
+ local node
+ if source.type == 'vm.node' then
+ node = source
+ else
+ node = vm.compileNode(source)
+ end
if node.lastInfer then
return node.lastInfer
end
local infer = setmetatable({
node = node,
- uri = guide.getUri(source),
+ uri = source.type ~= 'vm.node' and guide.getUri(source),
}, mt)
node.lastInfer = infer
@@ -199,24 +205,24 @@ function mt:_trim()
if self._hasTable and not self._hasClass then
self.views['table'] = true
end
- if self._hasClass then
- self:_eraseAlias()
- end
end
-function mt:_eraseAlias()
- local expandAlias = config.get(self.uri, 'Lua.hover.expandAlias')
+---@param uri uri
+---@return table<string, true>
+function mt:_eraseAlias(uri)
+ local drop = {}
+ local expandAlias = config.get(uri, 'Lua.hover.expandAlias')
for n in self.node:eachObject() do
if n.type == 'global' and n.cate == 'type' then
- for _, set in ipairs(n:getSets(self.uri)) do
+ for _, set in ipairs(n:getSets(uri)) do
if set.type == 'doc.alias' then
if expandAlias then
- self.views[n.name] = nil
+ drop[n.name] = true
else
for _, ext in ipairs(set.extends.types) do
local view = viewNodeSwitch(ext.type, ext, {})
if view and view ~= n.name then
- self.views[view] = nil
+ drop[view] = true
end
end
end
@@ -224,6 +230,7 @@ function mt:_eraseAlias()
end
end
end
+ return drop
end
---@param tp string
@@ -273,17 +280,16 @@ function mt:view(default, uri)
return 'any'
end
- if not next(self.views) then
- return default or 'unknown'
- end
-
- if self.cachedView then
- return self.cachedView
+ local drop
+ if self._hasClass then
+ drop = self:_eraseAlias(uri or self.uri)
end
local array = {}
for view in pairs(self.views) do
- array[#array+1] = view
+ if not drop or not drop[view] then
+ array[#array+1] = view
+ end
end
table.sort(array, function (a, b)
@@ -298,22 +304,29 @@ function mt:view(default, uri)
local max = #array
local limit = config.get(uri or self.uri, 'Lua.hover.enumsLimit')
- if max > limit then
- local view = string.format('%s...(+%d)'
- , table.concat(array, '|', 1, limit)
- , max - limit
- )
-
- self.cachedView = view
-
- return view
+ local view
+ if #array == 0 then
+ view = default or 'unknown'
else
- local view = table.concat(array, '|')
-
- self.cachedView = view
+ if max > limit then
+ view = string.format('%s...(+%d)'
+ , table.concat(array, '|', 1, limit)
+ , max - limit
+ )
+ else
+ view = table.concat(array, '|')
+ end
+ end
- return view
+ if self.node:isOptional() then
+ if max > 1 then
+ view = '(' .. view .. ')?'
+ else
+ view = view .. '?'
+ end
end
+
+ return view
end
function mt:eachView()
@@ -324,10 +337,10 @@ end
---@param other vm.infer
---@return vm.infer
function mt:merge(other)
- if self == m.NULL then
+ if self == vm.NULL then
return other
end
- if other == m.NULL then
+ if other == vm.NULL then
return self
end
@@ -390,8 +403,6 @@ end
---@param source parser.object
---@return string?
-function m.viewObject(source)
+function vm.viewObject(source)
return viewNodeSwitch(source.type, source, {})
end
-
-return m
diff --git a/script/vm/init.lua b/script/vm/init.lua
index 0058c698..f5003c11 100644
--- a/script/vm/init.lua
+++ b/script/vm/init.lua
@@ -1,4 +1,7 @@
local vm = require 'vm.vm'
+
+---@alias vm.object parser.object | vm.global | vm.generic
+
require 'vm.compiler'
require 'vm.value'
require 'vm.node'
@@ -8,5 +11,10 @@ require 'vm.field'
require 'vm.doc'
require 'vm.type'
require 'vm.library'
-require 'vm.manager'
+require 'vm.runner'
+require 'vm.infer'
+require 'vm.generic'
+require 'vm.sign'
+require 'vm.local-id'
+require 'vm.global'
return vm
diff --git a/script/vm/library.lua b/script/vm/library.lua
index 49f7adb0..e7bf4f42 100644
--- a/script/vm/library.lua
+++ b/script/vm/library.lua
@@ -13,24 +13,3 @@ function vm.getLibraryName(source)
end
return nil
end
-
-local globalLibraryNames = {
- 'arg', 'assert', 'error', 'collectgarbage', 'dofile', '_G', 'getfenv',
- 'getmetatable', 'ipairs', 'load', 'loadfile', 'loadstring',
- 'module', 'next', 'pairs', 'pcall', 'print', 'rawequal',
- 'rawget', 'rawlen', 'rawset', 'select', 'setfenv',
- 'setmetatable', 'tonumber', 'tostring', 'type', '_VERSION',
- 'warn', 'xpcall', 'require', 'unpack', 'bit32', 'coroutine',
- 'debug', 'io', 'math', 'os', 'package', 'string', 'table',
- 'utf8', 'newproxy',
-}
-local globalLibraryNamesMap
-function vm.isGlobalLibraryName(name)
- if not globalLibraryNamesMap then
- globalLibraryNamesMap = {}
- for _, v in ipairs(globalLibraryNames) do
- globalLibraryNamesMap[v] = true
- end
- end
- return globalLibraryNamesMap[name] or false
-end
diff --git a/script/vm/local-id.lua b/script/vm/local-id.lua
index 728de301..80c68769 100644
--- a/script/vm/local-id.lua
+++ b/script/vm/local-id.lua
@@ -1,13 +1,13 @@
local util = require 'utility'
local guide = require 'parser.guide'
+---@class vm
local vm = require 'vm.vm'
---@class parser.object
---@field _localID string
---@field _localIDs table<string, parser.object[]>
----@class vm.local-id
-local m = {}
+local compileLocalID, getLocal
local compileSwitch = util.switch()
: case 'local'
@@ -18,13 +18,13 @@ local compileSwitch = util.switch()
return
end
for _, ref in ipairs(source.ref) do
- m.compileLocalID(ref)
+ compileLocalID(ref)
end
end)
: case 'getlocal'
: call(function (source)
source._localID = ('%d'):format(source.node.start)
- m.compileLocalID(source.next)
+ compileLocalID(source.next)
end)
: case 'getfield'
: case 'setfield'
@@ -40,7 +40,7 @@ local compileSwitch = util.switch()
source._localID = parentID .. vm.ID_SPLITE .. key
source.field._localID = source._localID
if source.type == 'getfield' then
- m.compileLocalID(source.next)
+ compileLocalID(source.next)
end
end)
: case 'getmethod'
@@ -57,7 +57,7 @@ local compileSwitch = util.switch()
source._localID = parentID .. vm.ID_SPLITE .. key
source.method._localID = source._localID
if source.type == 'getmethod' then
- m.compileLocalID(source.next)
+ compileLocalID(source.next)
end
end)
: case 'getindex'
@@ -74,7 +74,7 @@ local compileSwitch = util.switch()
source._localID = parentID .. vm.ID_SPLITE .. key
source.index._localID = source._localID
if source.type == 'setindex' then
- m.compileLocalID(source.next)
+ compileLocalID(source.next)
end
end)
@@ -82,7 +82,7 @@ local leftSwitch = util.switch()
: case 'field'
: case 'method'
: call(function (source)
- return m.getLocal(source.parent)
+ return getLocal(source.parent)
end)
: case 'getfield'
: case 'setfield'
@@ -91,24 +91,36 @@ local leftSwitch = util.switch()
: case 'getindex'
: case 'setindex'
: call(function (source)
- return m.getLocal(source.node)
+ return getLocal(source.node)
end)
: case 'getlocal'
: call(function (source)
return source.node
end)
: case 'local'
+ : case 'self'
: call(function (source)
return source
end)
---@param source parser.object
---@return parser.object?
-function m.getLocal(source)
+function getLocal(source)
return leftSwitch(source.type, source)
end
-function m.compileLocalID(source)
+---@param id string
+---@param source parser.object
+function vm.insertLocalID(id, source)
+ local root = guide.getRoot(source)
+ if not root._localIDs then
+ root._localIDs = util.multiTable(2)
+ end
+ local sources = root._localIDs[id]
+ sources[#sources+1] = source
+end
+
+function compileLocalID(source)
if not source then
return
end
@@ -117,37 +129,33 @@ function m.compileLocalID(source)
return
end
compileSwitch(source.type, source)
- if not source._localID then
+ local id = source._localID
+ if not id then
return
end
- local root = guide.getRoot(source)
- if not root._localIDs then
- root._localIDs = util.multiTable(2)
- end
- local sources = root._localIDs[source._localID]
- sources[#sources+1] = source
+ vm.insertLocalID(id, source)
end
---@param source parser.object
----@return string|boolean
-function m.getID(source)
+---@return string?
+function vm.getLocalID(source)
if source._localID ~= nil then
return source._localID
end
source._localID = false
- local loc = m.getLocal(source)
+ local loc = getLocal(source)
if not loc then
return source._localID
end
- m.compileLocalID(loc)
+ compileLocalID(loc)
return source._localID
end
---@param source parser.object
---@param key? string
---@return parser.object[]?
-function m.getSources(source, key)
- local id = m.getID(source)
+function vm.getLocalSources(source, key)
+ local id = vm.getLocalID(source)
if not id then
return nil
end
@@ -166,8 +174,8 @@ end
---@param source parser.object
---@return parser.object[]
-function m.getFields(source)
- local id = m.getID(source)
+function vm.getLocalFields(source)
+ local id = vm.getLocalID(source)
if not id then
return nil
end
@@ -195,5 +203,3 @@ function m.getFields(source)
end
return fields
end
-
-return m
diff --git a/script/vm/local-manager.lua b/script/vm/local-manager.lua
deleted file mode 100644
index 51bafb24..00000000
--- a/script/vm/local-manager.lua
+++ /dev/null
@@ -1,40 +0,0 @@
-local util = require 'utility'
-local guide = require 'parser.guide'
-
----@class vm.local-node
-local m = {}
----@type table<uri, parser.object[]>
-m.locals = util.multiTable(2)
----@type table<parser.object, table<parser.object, boolean>>
-m.localSubs = util.multiTable(2, function ()
- return setmetatable({}, util.MODE_K)
-end)
----@type table<parser.object, boolean>
-m.allLocals = {}
-
----@param source parser.object
-function m.declareLocal(source)
- if m.allLocals[source] then
- return
- end
- m.allLocals[source] = true
- local uri = guide.getUri(source)
- local locals = m.locals[uri]
- locals[#locals+1] = source
-end
-
----@param uri uri
-function m.dropUri(uri)
- local locals = m.locals[uri]
- m.locals[uri] = nil
- for _, loc in ipairs(locals) do
- m.allLocals[loc] = nil
- local localSubs = m.localSubs[loc]
- m.localSubs[loc] = nil
- for source in pairs(localSubs) do
- source._node = nil
- end
- end
-end
-
-return m
diff --git a/script/vm/manager.lua b/script/vm/manager.lua
deleted file mode 100644
index 58255fca..00000000
--- a/script/vm/manager.lua
+++ /dev/null
@@ -1,26 +0,0 @@
-
-local files = require 'files'
-local globalManager = require 'vm.global-manager'
-local localManager = require 'vm.local-manager'
-
----@alias vm.object parser.object | vm.global | vm.generic
-
----@class vm.state
-local m = {}
-
-files.watch(function (ev, uri)
- if ev == 'update' then
- globalManager.dropUri(uri)
- localManager.dropUri(uri)
- local state = files.getState(uri)
- if state then
- globalManager.compileAst(state.ast)
- end
- end
- if ev == 'remove' then
- globalManager.dropUri(uri)
- localManager.dropUri(uri)
- end
-end)
-
-return m
diff --git a/script/vm/node.lua b/script/vm/node.lua
index 6906da7e..e76542aa 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -1,5 +1,4 @@
local files = require 'files'
-local localMgr = require 'vm.local-manager'
---@class vm
local vm = require 'vm.vm'
local ws = require 'workspace.workspace'
@@ -8,15 +7,14 @@ local ws = require 'workspace.workspace'
vm.nodeCache = {}
---@class vm.node
+---@field [integer] vm.object
local mt = {}
mt.__index = mt
+mt.id = 0
mt.type = 'vm.node'
mt.optional = nil
mt.lastInfer = nil
mt.data = nil
----@type vm.node[]
-mt._childs = nil
-mt._locked = false
---@param node vm.node | vm.object
function mt:merge(node)
@@ -30,20 +28,10 @@ function mt:merge(node)
if node:isOptional() then
self.optional = true
end
- if node._locked then
- if not self._childs then
- self._childs = {}
- end
- if not self._childs[node] then
- self._childs[#self._childs+1] = node
- self._childs[node] = true
- end
- else
- for _, obj in ipairs(node) do
- if not self[obj] then
- self[obj] = true
- self[#self+1] = obj
- end
+ for _, obj in ipairs(node) do
+ if not self[obj] then
+ self[obj] = true
+ self[#self+1] = obj
end
end
else
@@ -54,84 +42,25 @@ function mt:merge(node)
end
end
-function mt:_each(mark, callback)
- if mark[self] then
- return
- end
- mark[self] = true
- for i = 1, #self do
- callback(self[i])
- end
- local childs = self._childs
- if not childs then
- return
- end
- for i = 1, #childs do
- local child = childs[i]
- if not child:isLocked() then
- child:_each(mark, callback)
- end
- end
-end
-
-function mt:_expand()
- local childs = self._childs
- if not childs then
- return
- end
- self._childs = nil
-
- local mark = {}
- mark[self] = true
-
- local function insert(obj)
- if not self[obj] then
- self[obj] = true
- self[#self+1] = obj
- end
- end
-
- for i = 1, #childs do
- local child = childs[i]
- if child:isLocked() then
- if not self._childs then
- self._childs = {}
- end
- if not self._childs[child] then
- self._childs[#self._childs+1] = child
- self._childs[child] = true
- end
- else
- child:_each(mark, insert)
- end
- end
-end
-
---@return boolean
function mt:isEmpty()
- self:_expand()
return #self == 0
end
+function mt:clear()
+ self.optional = nil
+ for i, c in ipairs(self) do
+ self[i] = nil
+ self[c] = nil
+ end
+end
+
---@param n integer
---@return vm.object?
function mt:get(n)
- self:_expand()
return self[n]
end
-function mt:lock()
- self._locked = true
-end
-
-function mt:unlock()
- self._locked = false
-end
-
-function mt:isLocked()
- return self._locked == true
-end
-
function mt:setData(k, v)
if not self.data then
self.data = {}
@@ -147,49 +76,143 @@ function mt:getData(k)
end
function mt:addOptional()
- if self:isOptional() then
- return self
- end
self.optional = true
end
function mt:removeOptional()
- if not self:isOptional() then
- return self
- end
- self:_expand()
- for i = #self, 1, -1 do
- local n = self[i]
- if n.type == 'nil'
- or (n.type == 'boolean' and n[1] == false)
- or (n.type == 'doc.type.boolean' and n[1] == false) then
- self[i] = self[#self]
- self[#self] = nil
- end
- end
+ self:remove 'nil'
end
---@return boolean
function mt:isOptional()
- if self.optional ~= nil then
- return self.optional
+ return self.optional == true
+end
+
+---@return boolean
+function mt:hasFalsy()
+ if self.optional then
+ return true
end
- self:_expand()
for _, c in ipairs(self) do
if c.type == 'nil'
+ or (c.type == 'global' and c.cate == 'type' and c.name == 'nil')
+ or (c.type == 'global' and c.cate == 'type' and c.name == 'false')
or (c.type == 'boolean' and c[1] == false)
or (c.type == 'doc.type.boolean' and c[1] == false) then
- self.optional = true
return true
end
end
- self.optional = false
return false
end
+---@return boolean
+function mt:isNullable()
+ if self.optional then
+ return true
+ end
+ if #self == 0 then
+ return true
+ end
+ for _, c in ipairs(self) do
+ if c.type == 'nil'
+ or (c.type == 'global' and c.cate == 'type' and c.name == 'nil')
+ or (c.type == 'global' and c.cate == 'type' and c.name == 'any') then
+ return true
+ end
+ end
+ return false
+end
+
+---@return vm.node
+function mt:setTruthy()
+ if self.optional == true then
+ self.optional = nil
+ end
+ local hasBoolean
+ for index = #self, 1, -1 do
+ local c = self[index]
+ if c.type == 'nil'
+ or (c.type == 'global' and c.cate == 'type' and c.name == 'nil')
+ or (c.type == 'global' and c.cate == 'type' and c.name == 'false')
+ or (c.type == 'boolean' and c[1] == false)
+ or (c.type == 'doc.type.boolean' and c[1] == false) then
+ table.remove(self, index)
+ self[c] = nil
+ goto CONTINUE
+ end
+ if (c.type == 'global' and c.cate == 'type' and c.name == 'boolean')
+ or (c.type == 'boolean' or c.type == 'doc.type.boolean') then
+ hasBoolean = true
+ table.remove(self, index)
+ self[c] = nil
+ goto CONTINUE
+ end
+ ::CONTINUE::
+ end
+ if hasBoolean then
+ self[#self+1] = vm.declareGlobal('type', 'true')
+ end
+end
+
+---@return vm.node
+function mt:setFalsy()
+ if self.optional == false then
+ self.optional = nil
+ end
+ local hasBoolean
+ for index = #self, 1, -1 do
+ local c = self[index]
+ if c.type == 'nil'
+ or (c.type == 'global' and c.cate == 'type' and c.name == 'nil')
+ or (c.type == 'global' and c.cate == 'type' and c.name == 'false')
+ or (c.type == 'boolean' and c[1] == true)
+ or (c.type == 'doc.type.boolean' and c[1] == true) then
+ goto CONTINUE
+ end
+ if (c.type == 'global' and c.cate == 'type' and c.name == 'boolean')
+ or (c.type == 'boolean' or c.type == 'doc.type.boolean') then
+ hasBoolean = true
+ table.remove(self, index)
+ self[c] = nil
+ end
+ ::CONTINUE::
+ end
+ if hasBoolean then
+ self[#self+1] = vm.declareGlobal('type', 'false')
+ end
+end
+
+---@param name string
+function mt:remove(name)
+ if name == 'nil' and self.optional == true then
+ self.optional = nil
+ end
+ for index = #self, 1, -1 do
+ local c = self[index]
+ if (c.type == 'global' and c.cate == 'type' and c.name == name)
+ or (c.type == name)
+ or (c.type == 'doc.type.integer' and (name == 'number' or name == 'integer'))
+ or (c.type == 'doc.type.boolean' and name == 'boolean')
+ or (c.type == 'doc.type.table' and name == 'table')
+ or (c.type == 'doc.type.array' and name == 'table')
+ or (c.type == 'doc.type.function' and name == 'function') then
+ table.remove(self, index)
+ self[c] = nil
+ end
+ end
+end
+
+---@param node vm.node
+function mt:removeNode(node)
+ for _, c in ipairs(node) do
+ if c.type == 'global' and c.cate == 'type' then
+ self:remove(c.name)
+ end
+ end
+end
+
---@return fun():vm.object
function mt:eachObject()
- self:_expand()
local i = 0
return function ()
i = i + 1
@@ -197,12 +220,21 @@ function mt:eachObject()
end
end
----@param source parser.object | vm.generic
+---@return vm.node
+function mt:copy()
+ return vm.createNode(self)
+end
+
+---@param source vm.object
---@param node vm.node | vm.object
---@param cover? boolean
function vm.setNode(source, node, cover)
if not node then
- error('Can not set nil node')
+ if TEST then
+ error('Can not set nil node')
+ else
+ log.error('Can not set nil node')
+ end
end
if source.type == 'global' then
error('Can not set node to global')
@@ -216,13 +248,14 @@ function vm.setNode(source, node, cover)
me:merge(node)
else
if node.type == 'vm.node' then
- vm.nodeCache[source] = node
+ vm.nodeCache[source] = node:copy()
else
vm.nodeCache[source] = vm.createNode(node)
end
end
end
+---@param source vm.object
---@return vm.node?
function vm.getNode(source)
return vm.nodeCache[source]
@@ -256,11 +289,16 @@ function vm.clearNodeCache()
vm.nodeCache = {}
end
+local ID = 0
+
---@param a? vm.node | vm.object
---@param b? vm.node | vm.object
---@return vm.node
function vm.createNode(a, b)
- local node = setmetatable({}, mt)
+ ID = ID + 1
+ local node = setmetatable({
+ id = ID,
+ }, mt)
if a then
node:merge(a)
end
diff --git a/script/vm/ref.lua b/script/vm/ref.lua
index 65e8fdab..545c294a 100644
--- a/script/vm/ref.lua
+++ b/script/vm/ref.lua
@@ -2,8 +2,6 @@
local vm = require 'vm.vm'
local util = require 'utility'
local guide = require 'parser.guide'
-local localID = require 'vm.local-id'
-local globalMgr = require 'vm.global-manager'
local files = require 'files'
local await = require 'await'
local progress = require 'progress'
@@ -242,7 +240,7 @@ end
---@param source parser.object
---@param pushResult fun(src: parser.object)
local function searchByLocalID(source, pushResult)
- local idSources = localID.getSources(source)
+ local idSources = vm.getLocalSources(source)
if not idSources then
return
end
@@ -291,7 +289,7 @@ end
---@async
---@param source parser.object
----@param fileNotify fun(uri: uri): boolean
+---@param fileNotify? fun(uri: uri): boolean
function vm.getRefs(source, fileNotify)
local results = {}
local mark = {}
diff --git a/script/vm/runner.lua b/script/vm/runner.lua
new file mode 100644
index 00000000..9fe0f172
--- /dev/null
+++ b/script/vm/runner.lua
@@ -0,0 +1,444 @@
+---@class vm
+local vm = require 'vm.vm'
+local guide = require 'parser.guide'
+
+---@class vm.runner
+---@field loc parser.object
+---@field mainBlock parser.object
+---@field blocks table<parser.object, true>
+---@field steps vm.runner.step[]
+local mt = {}
+mt.__index = mt
+mt.index = 1
+
+---@class parser.object
+---@field _casts parser.object[]
+
+---@class vm.runner.step
+---@field type 'truthy' | 'falsy' | 'as' | 'add' | 'remove' | 'object' | 'save' | 'push' | 'merge' | 'cast'
+---@field pos integer
+---@field order? integer
+---@field node? vm.node
+---@field object? parser.object
+---@field name? string
+---@field cast? parser.object
+---@field tag? string
+---@field copy? boolean
+---@field new? boolean
+---@field ref1? vm.runner.step
+---@field ref2? vm.runner.step
+
+---@param filter parser.object
+---@param outStep vm.runner.step
+---@param blockStep vm.runner.step
+function mt:_compileNarrowByFilter(filter, outStep, blockStep)
+ if not filter then
+ return
+ end
+ if filter.type == 'paren' then
+ if filter.exp then
+ self:_compileNarrowByFilter(filter.exp, outStep, blockStep)
+ end
+ return
+ end
+ if filter.type == 'unary' then
+ if not filter.op
+ or not filter[1] then
+ return
+ end
+ if filter.op.type == 'not' then
+ local exp = filter[1]
+ if exp.type == 'getlocal' and exp.node == self.loc then
+ self.steps[#self.steps+1] = {
+ type = 'falsy',
+ pos = filter.finish,
+ new = true,
+ }
+ self.steps[#self.steps+1] = {
+ type = 'truthy',
+ pos = filter.finish,
+ ref1 = outStep,
+ }
+ end
+ end
+ elseif filter.type == 'binary' then
+ if not filter.op
+ or not filter[1]
+ or not filter[2] then
+ return
+ end
+ if filter.op.type == 'and' then
+ local dummyStep = {
+ type = 'save',
+ copy = true,
+ ref1 = outStep,
+ pos = filter.start - 1,
+ }
+ self.steps[#self.steps+1] = dummyStep
+ self:_compileNarrowByFilter(filter[1], dummyStep, blockStep)
+ self:_compileNarrowByFilter(filter[2], dummyStep, blockStep)
+ end
+ if filter.op.type == 'or' then
+ self:_compileNarrowByFilter(filter[1], outStep, blockStep)
+ local dummyStep = {
+ type = 'push',
+ copy = true,
+ ref1 = outStep,
+ pos = filter.op.finish,
+ }
+ self.steps[#self.steps+1] = dummyStep
+ self:_compileNarrowByFilter(filter[2], outStep, dummyStep)
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ tag = 'or reset',
+ ref1 = blockStep,
+ pos = filter.finish,
+ }
+ end
+ if filter.op.type == '=='
+ or filter.op.type == '~=' then
+ local loc, exp
+ for i = 1, 2 do
+ loc = filter[i]
+ if loc.type == 'getlocal' and loc.node == self.loc then
+ exp = filter[i % 2 + 1]
+ break
+ end
+ end
+ if not loc or not exp then
+ return
+ end
+ if guide.isLiteral(exp) then
+ if filter.op.type == '==' then
+ self.steps[#self.steps+1] = {
+ type = 'remove',
+ name = exp.type,
+ pos = filter.finish,
+ ref1 = outStep,
+ }
+ self.steps[#self.steps+1] = {
+ type = 'as',
+ name = exp.type,
+ pos = filter.finish,
+ new = true,
+ }
+ end
+ if filter.op.type == '~=' then
+ self.steps[#self.steps+1] = {
+ type = 'as',
+ name = exp.type,
+ pos = filter.finish,
+ ref1 = outStep,
+ }
+ self.steps[#self.steps+1] = {
+ type = 'remove',
+ name = exp.type,
+ pos = filter.finish,
+ new = true,
+ }
+ end
+ end
+ end
+ else
+ if filter.type == 'getlocal' and filter.node == self.loc then
+ self.steps[#self.steps+1] = {
+ type = 'truthy',
+ pos = filter.finish,
+ new = true,
+ }
+ self.steps[#self.steps+1] = {
+ type = 'falsy',
+ pos = filter.finish,
+ ref1 = outStep,
+ }
+ end
+ end
+end
+
+---@param block parser.object
+function mt:_compileBlock(block)
+ if self.blocks[block] then
+ return
+ end
+ self.blocks[block] = true
+ if block == self.mainBlock then
+ return
+ end
+
+ local parentBlock = guide.getParentBlock(block)
+ self:_compileBlock(parentBlock)
+
+ if block.type == 'if' then
+ ---@type vm.runner.step[]
+ local finals = {}
+ for i, childBlock in ipairs(block) do
+ local blockStep = {
+ type = 'save',
+ tag = 'block',
+ copy = true,
+ pos = childBlock.start,
+ }
+ local outStep = {
+ type = 'save',
+ tag = 'out',
+ copy = true,
+ pos = childBlock.start,
+ }
+ self.steps[#self.steps+1] = blockStep
+ self.steps[#self.steps+1] = outStep
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ ref1 = blockStep,
+ pos = childBlock.start,
+ }
+ self:_compileNarrowByFilter(childBlock.filter, outStep, blockStep)
+ if not childBlock.hasReturn
+ and not childBlock.hasGoTo
+ and not childBlock.hasBreak then
+ local finalStep = {
+ type = 'save',
+ pos = childBlock.finish,
+ tag = 'final #' .. i,
+ }
+ finals[#finals+1] = finalStep
+ self.steps[#self.steps+1] = finalStep
+ end
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ tag = 'reset child',
+ ref1 = outStep,
+ pos = childBlock.finish,
+ }
+ end
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ tag = 'reset if',
+ pos = block.finish,
+ copy = true,
+ }
+ for _, final in ipairs(finals) do
+ self.steps[#self.steps+1] = {
+ type = 'merge',
+ ref2 = final,
+ pos = block.finish,
+ }
+ end
+ end
+
+ if block.type == 'function'
+ or block.type == 'while'
+ or block.type == 'loop'
+ or block.type == 'in'
+ or block.type == 'repeat'
+ or block.type == 'for' then
+ local savePoint = {
+ type = 'save',
+ copy = true,
+ pos = block.start,
+ }
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ copy = true,
+ pos = block.start,
+ }
+ self.steps[#self.steps+1] = savePoint
+ self.steps[#self.steps+1] = {
+ type = 'push',
+ pos = block.finish,
+ ref1 = savePoint,
+ }
+ end
+end
+
+---@return parser.object[]
+function mt:_getCasts()
+ local root = guide.getRoot(self.loc)
+ if not root._casts then
+ root._casts = {}
+ local docs = root.docs
+ for _, doc in ipairs(docs) do
+ if doc.type == 'doc.cast' and doc.loc then
+ root._casts[#root._casts+1] = doc
+ end
+ end
+ end
+ return root._casts
+end
+
+function mt:_preCompile()
+ local startPos = self.loc.start
+ local finishPos = 0
+
+ for _, ref in ipairs(self.loc.ref) do
+ self.steps[#self.steps+1] = {
+ type = 'object',
+ object = ref,
+ pos = ref.range or ref.start,
+ }
+ if ref.start > finishPos then
+ finishPos = ref.start
+ end
+ local block = guide.getParentBlock(ref)
+ self:_compileBlock(block)
+ end
+
+ for i, step in ipairs(self.steps) do
+ if step.type ~= 'object' then
+ step.order = i
+ end
+ end
+
+ local casts = self:_getCasts()
+ for _, cast in ipairs(casts) do
+ if cast.loc[1] == self.loc[1]
+ and cast.start > startPos
+ and cast.finish < finishPos
+ and guide.getLocal(self.loc, self.loc[1], cast.start) == self.loc then
+ self.steps[#self.steps+1] = {
+ type = 'cast',
+ cast = cast,
+ pos = cast.start,
+ }
+ end
+ end
+
+ table.sort(self.steps, function (a, b)
+ if a.pos == b.pos then
+ return (a.order or 0) < (b.order or 0)
+ else
+ return a.pos < b.pos
+ end
+ end)
+end
+
+---@param loc parser.object
+---@param node vm.node
+---@return vm.node
+local function checkAssert(loc, node)
+ local parent = loc.parent
+ if parent.type == 'binary' then
+ if parent.op and (parent.op.type == '~=' or parent.op.type == '==') then
+ local exp
+ for i = 1, 2 do
+ if parent[i] == loc then
+ exp = parent[i % 2 + 1]
+ end
+ end
+ if exp and guide.isLiteral(exp) then
+ local callargs = parent.parent
+ if callargs.type == 'callargs'
+ and callargs.parent.node.special == 'assert'
+ and callargs[1] == parent then
+ if parent.op.type == '~=' then
+ node:remove(exp.type)
+ end
+ if parent.op.type == '==' then
+ node = vm.compileNode(exp)
+ end
+ end
+ end
+ end
+ end
+ if parent.type == 'callargs'
+ and parent.parent.node.special == 'assert'
+ and parent[1] == loc then
+ node:setTruthy()
+ end
+ return node
+end
+
+---@param callback fun(src: parser.object, node: vm.node)
+function mt:launch(callback)
+ local topNode = vm.getNode(self.loc):copy()
+ for _, step in ipairs(self.steps) do
+ local node = step.ref1 and step.ref1.node or topNode
+ if step.type == 'truthy' then
+ if step.new then
+ node = node:copy()
+ topNode = node
+ end
+ node:setTruthy()
+ elseif step.type == 'falsy' then
+ if step.new then
+ node = node:copy()
+ topNode = node
+ end
+ node:setFalsy()
+ elseif step.type == 'as' then
+ if step.new then
+ topNode = vm.createNode(vm.getGlobal('type', step.name))
+ else
+ node:clear()
+ node:merge(vm.getGlobal('type', step.name))
+ end
+ elseif step.type == 'add' then
+ if step.new then
+ node = node:copy()
+ topNode = node
+ end
+ node:merge(vm.getGlobal('type', step.name))
+ elseif step.type == 'remove' then
+ if step.new then
+ node = node:copy()
+ topNode = node
+ end
+ node:remove(step.name)
+ elseif step.type == 'object' then
+ topNode = callback(step.object, node) or node
+ if step.object.type == 'getlocal' then
+ topNode = checkAssert(step.object, node)
+ end
+ elseif step.type == 'save' then
+ if step.copy then
+ node = node:copy()
+ end
+ step.node = node
+ elseif step.type == 'push' then
+ if step.copy then
+ node = node:copy()
+ end
+ topNode = node
+ elseif step.type == 'merge' then
+ node:merge(step.ref2.node)
+ elseif step.type == 'cast' then
+ topNode = node:copy()
+ for _, cast in ipairs(step.cast.casts) do
+ if cast.mode == '+' then
+ if cast.optional then
+ topNode:addOptional()
+ end
+ if cast.extends then
+ topNode:merge(vm.compileNode(cast.extends))
+ end
+ elseif cast.mode == '-' then
+ if cast.optional then
+ topNode:removeOptional()
+ end
+ if cast.extends then
+ topNode:removeNode(vm.compileNode(cast.extends))
+ end
+ else
+ if cast.extends then
+ topNode:clear()
+ topNode:merge(vm.compileNode(cast.extends))
+ end
+ end
+ end
+ end
+ end
+end
+
+---@param loc parser.object
+---@return vm.runner
+function vm.createRunner(loc)
+ local self = setmetatable({
+ loc = loc,
+ mainBlock = guide.getParentBlock(loc),
+ blocks = {},
+ steps = {},
+ }, mt)
+
+ self:_preCompile()
+
+ return self
+end
diff --git a/script/vm/sign.lua b/script/vm/sign.lua
index 2d45a5a7..fe112bc2 100644
--- a/script/vm/sign.lua
+++ b/script/vm/sign.lua
@@ -1,6 +1,6 @@
local guide = require 'parser.guide'
+---@class vm
local vm = require 'vm.vm'
-local infer = require 'vm.infer'
---@class vm.sign
---@field parent parser.object
@@ -16,12 +16,12 @@ end
---@param uri uri
---@param args parser.object
+---@param removeGeneric true?
---@return table<string, vm.node>
-function mt:resolve(uri, args)
+function mt:resolve(uri, args, removeGeneric)
if not args then
return nil
end
- local globalMgr = require 'vm.global-manager'
local resolved = {}
---@param object parser.object
@@ -33,7 +33,7 @@ function mt:resolve(uri, args)
-- 'number' -> `T`
for n in node:eachObject() do
if n.type == 'string' then
- local type = globalMgr.declareGlobal('type', n[1], guide.getUri(n))
+ local type = vm.declareGlobal('type', n[1], guide.getUri(n))
resolved[key] = vm.createNode(type, resolved[key])
end
end
@@ -48,6 +48,19 @@ function mt:resolve(uri, args)
-- number[] -> T[]
resolve(object.node, vm.compileNode(n.node))
end
+ if n.type == 'doc.type.table' then
+ -- { [integer]: number } -> T[]
+ local tvalueNode = vm.getTableValue(uri, node, 'integer')
+ if tvalueNode then
+ resolve(object.node, tvalueNode)
+ end
+ end
+ if n.type == 'global' and n.cate == 'type' then
+ -- ---@field [integer]: number -> T[]
+ vm.getClassFields(uri, n, vm.declareGlobal('type', 'integer'), false, function (field)
+ resolve(object.node, vm.compileNode(field.extends))
+ end)
+ end
end
end
if object.type == 'doc.type.table' then
@@ -98,7 +111,7 @@ function mt:resolve(uri, args)
goto CONTINUE
end
end
- local view = infer.viewObject(obj)
+ local view = vm.viewObject(obj)
if view then
knownTypes[view] = true
end
@@ -114,10 +127,10 @@ function mt:resolve(uri, args)
local function buildArgNode(argNode, knownTypes)
local newArgNode = vm.createNode()
for n in argNode:eachObject() do
- if argNode:isOptional() and vm.isFalsy(n) then
+ if argNode:hasFalsy() then
goto CONTINUE
end
- local view = infer.viewObject(n)
+ local view = vm.viewObject(n)
if knownTypes[view] then
goto CONTINUE
end
@@ -156,7 +169,7 @@ function mt:resolve(uri, args)
end
---@return vm.sign
-return function ()
+function vm.createSign()
local genericMgr = setmetatable({
signList = {},
}, mt)
diff --git a/script/vm/type.lua b/script/vm/type.lua
index fa02d19e..c3264993 100644
--- a/script/vm/type.lua
+++ b/script/vm/type.lua
@@ -1,4 +1,3 @@
-local globalMgr = require 'vm.global-manager'
---@class vm
local vm = require 'vm.vm'
@@ -9,10 +8,10 @@ local vm = require 'vm.vm'
---@return boolean
function vm.isSubType(uri, child, parent, mark)
if type(parent) == 'string' then
- parent = vm.createNode(globalMgr.getGlobal('type', parent))
+ parent = vm.createNode(vm.getGlobal('type', parent))
end
if type(child) == 'string' then
- child = vm.createNode(globalMgr.getGlobal('type', child))
+ child = vm.createNode(vm.getGlobal('type', child))
end
if not child or not parent then
@@ -134,7 +133,7 @@ function vm.getTableKey(uri, tnode, vnode)
end
end
if tn.type == 'doc.type.array' then
- result:merge(globalMgr.getGlobal('type', 'integer'))
+ result:merge(vm.declareGlobal('type', 'integer'))
end
if tn.type == 'table' then
for _, field in ipairs(tn) do
@@ -144,10 +143,10 @@ function vm.getTableKey(uri, tnode, vnode)
end
end
if field.type == 'tablefield' then
- result:merge(globalMgr.getGlobal('type', 'string'))
+ result:merge(vm.declareGlobal('type', 'string'))
end
if field.type == 'tableexp' then
- result:merge(globalMgr.getGlobal('type', 'integer'))
+ result:merge(vm.declareGlobal('type', 'integer'))
end
end
end
diff --git a/script/vm/value.lua b/script/vm/value.lua
index a784be2a..d29ca9d0 100644
--- a/script/vm/value.lua
+++ b/script/vm/value.lua
@@ -17,7 +17,16 @@ function vm.test(source)
hasTrue = true
end
if n[1] == false then
- hasTrue = false
+ hasFalse = true
+ end
+ end
+ if n.type == 'global' and n.cate == 'type' then
+ if n.name == 'true' then
+ hasTrue = true
+ end
+ if n.name == 'false'
+ or n.name == 'nil' then
+ hasFalse = true
end
end
if n.type == 'nil' then
@@ -41,28 +50,9 @@ function vm.test(source)
end
end
----@param source parser.object
----@return boolean
-function vm.isFalsy(source)
- if source.type == 'nil' then
- return true
- end
- if source.type == 'boolean'
- or source.type == 'doc.type.boolean' then
- return source[1] == false
- end
- return false
-end
-
---@param v vm.object
---@return string?
local function getUnique(v)
- if v.type == 'local' then
- return ('loc:%s@%d'):format(guide.getUri(v), v.start)
- end
- if v.type == 'global' then
- return ('%s:%s'):format(v.cate, v.name)
- end
if v.type == 'boolean' then
if v[1] == nil then
return false
diff --git a/script/vm/vm.lua b/script/vm/vm.lua
index 3c1762bf..8117d311 100644
--- a/script/vm/vm.lua
+++ b/script/vm/vm.lua
@@ -23,6 +23,7 @@ function m.getSpecial(source)
return source.special
end
+---@return string?
function m.getKeyName(source)
if not source then
return nil