diff options
Diffstat (limited to 'script/vm')
-rw-r--r-- | script/vm/compiler.lua | 459 | ||||
-rw-r--r-- | script/vm/def.lua | 15 | ||||
-rw-r--r-- | script/vm/doc.lua | 11 | ||||
-rw-r--r-- | script/vm/field.lua | 10 | ||||
-rw-r--r-- | script/vm/generic.lua | 5 | ||||
-rw-r--r-- | script/vm/global-manager.lua | 364 | ||||
-rw-r--r-- | script/vm/global.lua | 431 | ||||
-rw-r--r-- | script/vm/infer.lua | 115 | ||||
-rw-r--r-- | script/vm/init.lua | 10 | ||||
-rw-r--r-- | script/vm/library.lua | 21 | ||||
-rw-r--r-- | script/vm/local-id.lua | 62 | ||||
-rw-r--r-- | script/vm/local-manager.lua | 40 | ||||
-rw-r--r-- | script/vm/manager.lua | 26 | ||||
-rw-r--r-- | script/vm/node.lua | 260 | ||||
-rw-r--r-- | script/vm/ref.lua | 6 | ||||
-rw-r--r-- | script/vm/runner.lua | 444 | ||||
-rw-r--r-- | script/vm/sign.lua | 29 | ||||
-rw-r--r-- | script/vm/type.lua | 11 | ||||
-rw-r--r-- | script/vm/value.lua | 30 | ||||
-rw-r--r-- | script/vm/vm.lua | 1 |
20 files changed, 1511 insertions, 839 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 8126f393..75620d19 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1,10 +1,6 @@ local guide = require 'parser.guide' local util = require 'utility' -local localID = require 'vm.local-id' -local globalMgr = require 'vm.global-manager' -local signMgr = require 'vm.sign' local config = require 'config' -local genericMgr = require 'vm.generic' local rpath = require 'workspace.require-path' local files = require 'files' ---@class vm @@ -13,7 +9,6 @@ local vm = require 'vm.vm' ---@class parser.object ---@field _compiledNodes boolean ---@field _node vm.node ----@field _localBase table ---@field _globalBase table local searchFieldSwitch = util.switch() @@ -54,7 +49,7 @@ local searchFieldSwitch = util.switch() : case 'string' : call(function (suri, source, key, ref, pushResult) -- change to `string: stringlib` ? - local stringlib = globalMgr.getGlobal('type', 'stringlib') + local stringlib = vm.getGlobal('type', 'stringlib') if stringlib then vm.getClassFields(suri, stringlib, key, ref, pushResult) end @@ -64,9 +59,9 @@ local searchFieldSwitch = util.switch() : call(function (suri, node, key, ref, pushResult) local fields if key then - fields = localID.getSources(node, key) + fields = vm.getLocalSources(node, key) else - fields = localID.getFields(node) + fields = vm.getLocalFields(node) end if fields then for _, src in ipairs(fields) do @@ -119,7 +114,7 @@ local searchFieldSwitch = util.switch() if type(key) ~= 'string' then return end - local global = globalMgr.getGlobal('variable', node.name, key) + local global = vm.getGlobal('variable', node.name, key) if global then for _, set in ipairs(global:getSets(suri)) do pushResult(set) @@ -131,7 +126,7 @@ local searchFieldSwitch = util.switch() end end else - local globals = globalMgr.getFields('variable', node.name) + local globals = vm.getGlobalFields('variable', node.name) for _, global in ipairs(globals) do for _, set in ipairs(global:getSets(suri)) do pushResult(set) @@ -158,7 +153,7 @@ local searchFieldSwitch = util.switch() if type(key) ~= 'string' then return end - local global = globalMgr.getGlobal('variable', node.name, key) + local global = vm.getGlobal('variable', node.name, key) if global then for _, set in ipairs(global:getSets(suri)) do pushResult(set) @@ -168,7 +163,7 @@ local searchFieldSwitch = util.switch() end end else - local globals = globalMgr.getFields('variable', node.name) + local globals = vm.getGlobalFields('variable', node.name) for _, global in ipairs(globals) do for _, set in ipairs(global:getSets(suri)) do pushResult(set) @@ -185,7 +180,7 @@ local searchFieldSwitch = util.switch() end) -function vm.getClassFields(suri, node, key, ref, pushResult) +function vm.getClassFields(suri, object, key, ref, pushResult) local mark = {} local function searchClass(class, searchedFields) @@ -201,11 +196,51 @@ function vm.getClassFields(suri, node, key, ref, pushResult) local hasFounded = {} for _, field in ipairs(set.fields) do local fieldKey = guide.getKeyName(field) - if key == nil - or fieldKey == key then - if not searchedFields[fieldKey] then - pushResult(field) - hasFounded[fieldKey] = true + if fieldKey then + -- ---@field x boolean -> class.x + if key == nil + or fieldKey == key then + if not searchedFields[fieldKey] then + pushResult(field) + hasFounded[fieldKey] = true + end + end + end + if not hasFounded[fieldKey] then + local keyType = type(key) + if keyType == 'table' then + -- ---@field [integer] boolean -> class[integer] + local fieldNode = vm.compileNode(field.field) + if vm.isSubType(suri, key.name, fieldNode) then + local nkey = '|' .. key.name + if not searchedFields[nkey] then + pushResult(field) + hasFounded[nkey] = true + end + end + else + local typeName + if keyType == 'number' then + if math.tointeger(key) then + typeName = 'integer' + else + typeName = 'number' + end + elseif keyType == 'boolean' + or keyType == 'string' then + typeName = keyType + end + if typeName then + -- ---@field [integer] boolean -> class[1] + local fieldNode = vm.compileNode(field.field) + if vm.isSubType(suri, typeName, fieldNode) then + local nkey = '|' .. typeName + if not searchedFields[nkey] then + pushResult(field) + hasFounded[nkey] = true + end + end + end end end end @@ -214,19 +249,23 @@ function vm.getClassFields(suri, node, key, ref, pushResult) for _, src in ipairs(set.bindSources) do searchFieldSwitch(src.type, suri, src, key, ref, function (field) local fieldKey = guide.getKeyName(field) - if not searchedFields[fieldKey] - and guide.isSet(field) then - hasFounded[fieldKey] = true - pushResult(field) + if fieldKey then + if not searchedFields[fieldKey] + and guide.isSet(field) then + hasFounded[fieldKey] = true + pushResult(field) + end end end) if src.value and src.value.type == 'table' then searchFieldSwitch('table', suri, src.value, key, ref, function (field) local fieldKey = guide.getKeyName(field) - if not searchedFields[fieldKey] - and guide.isSet(field) then - hasFounded[fieldKey] = true - pushResult(field) + if fieldKey then + if not searchedFields[fieldKey] + and guide.isSet(field) then + hasFounded[fieldKey] = true + pushResult(field) + end end end) end @@ -239,7 +278,7 @@ function vm.getClassFields(suri, node, key, ref, pushResult) end for _, extend in ipairs(set.extends) do if extend.type == 'doc.extends.name' then - local extendType = globalMgr.getGlobal('type', extend[1]) + local extendType = vm.getGlobal('type', extend[1]) if extendType then searchClass(extendType, searchedFields) end @@ -253,12 +292,12 @@ function vm.getClassFields(suri, node, key, ref, pushResult) local function searchGlobal(class) if class.cate == 'type' and class.name == '_G' then if key == nil then - local sets = globalMgr.getGlobalSets(suri, 'variable') + local sets = vm.getGlobalSets(suri, 'variable') for _, set in ipairs(sets) do pushResult(set) end else - local global = globalMgr.getGlobal('variable', key) + local global = vm.getGlobal('variable', key) if global then for _, set in ipairs(global:getSets(suri)) do pushResult(set) @@ -268,8 +307,8 @@ function vm.getClassFields(suri, node, key, ref, pushResult) end end - searchClass(node) - searchGlobal(node) + searchClass(object) + searchGlobal(object) end ---@class parser.object @@ -283,10 +322,13 @@ local function getObjectSign(source) end source._sign = false if source.type == 'function' then + if not source.bindDocs then + return false + end for _, doc in ipairs(source.bindDocs) do if doc.type == 'doc.generic' then if not source._sign then - source._sign = signMgr() + source._sign = vm.createSign() break end end @@ -314,14 +356,18 @@ local function getObjectSign(source) if not hasGeneric then return false end - source._sign = signMgr() + source._sign = vm.createSign() if source.type == 'doc.type.function' then for _, arg in ipairs(source.args) do - local argNode = vm.compileNode(arg.extends) - if arg.optional then - argNode:addOptional() + if arg.extends then + local argNode = vm.compileNode(arg.extends) + if arg.optional then + argNode:addOptional() + end + source._sign:addSign(argNode) + else + source._sign:addSign(vm.createNode()) end - source._sign:addSign(argNode) end end end @@ -354,7 +400,7 @@ function vm.getReturnOfFunction(func, index) if not sign then return rtn end - return genericMgr(rtn, sign) + return vm.createGeneric(rtn, sign) end end @@ -455,6 +501,9 @@ local function getReturn(func, index, args) result:merge(rnode) end end + if result and returnNode:isOptional() then + result:addOptional() + end end end end @@ -462,6 +511,25 @@ local function getReturn(func, index, args) return result end +---@param source parser.object +---@return boolean +local function bindAs(source) + local root = guide.getRoot(source) + local docs = root.docs + if not docs then + return + end + for _, doc in ipairs(docs) do + if doc.type == 'doc.as' and doc.originalComment.start == source.finish + 2 then + if doc.as then + vm.setNode(source, vm.compileNode(doc.as), true) + end + return true + end + end + return false +end + local function bindDocs(source) local isParam = source.parent.type == 'funcargs' or source.parent.type == 'in' @@ -485,7 +553,11 @@ local function bindDocs(source) end if doc.type == 'doc.param' then if isParam and source[1] == doc.param[1] then - vm.setNode(source, vm.compileNode(doc)) + local node = vm.compileNode(doc) + if doc.optional then + node:addOptional() + end + vm.setNode(source, node) return true end end @@ -503,12 +575,17 @@ local function bindDocs(source) vm.setNode(source, vm.compileNode(ast)) return true end + if doc.type == 'doc.overload' then + if not isParam then + vm.setNode(source, vm.compileNode(doc)) + end + end end return false end local function compileByLocalID(source) - local sources = localID.getSources(source) + local sources = vm.getLocalSources(source) if not sources then return end @@ -571,7 +648,7 @@ local function selectNode(source, list, index) if exp.type == 'call' then result = getReturn(exp.node, index, exp.args) if not result then - vm.setNode(source, globalMgr.getGlobal('type', 'unknown')) + vm.setNode(source, vm.declareGlobal('type', 'unknown')) return vm.getNode(source) end else @@ -597,7 +674,7 @@ local function selectNode(source, list, index) end end if not hasKnownType then - rtnNode:merge(globalMgr.getGlobal('type', 'unknown')) + rtnNode:merge(vm.declareGlobal('type', 'unknown')) end vm.setNode(source, rtnNode) return rtnNode @@ -664,10 +741,21 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex) for n in callNode:eachObject() do if n.type == 'function' then + local sign = getObjectSign(n) local farg = getFuncArg(n, myIndex) if farg then for fn in vm.compileNode(farg):eachObject() do if isValidCallArgNode(arg, fn) then + if fn.type == 'doc.type.function' then + if sign then + local generic = vm.createGeneric(fn, sign) + local args = {} + for i = fixIndex + 1, myIndex - 1 do + args[#args+1] = call.args[i] + end + fn = generic:resolve(guide.getUri(call), args) + end + end vm.setNode(arg, fn) end end @@ -716,29 +804,19 @@ function vm.compileCallArg(arg, call, index) if call.node.special == 'pcall' or call.node.special == 'xpcall' then local fixIndex = call.node.special == 'pcall' and 1 or 2 - callNode = vm.compileNode(call.args[1]) - compileCallArgNode(arg, call, callNode, fixIndex, index - fixIndex) + if call.args and call.args[1] then + callNode = vm.compileNode(call.args[1]) + compileCallArgNode(arg, call, callNode, fixIndex, index - fixIndex) + end end return vm.getNode(arg) end ---@param source parser.object ---@return vm.node -local function compileLocalBase(source) - if not source._localBase then - source._localBase = { - type = 'localbase', - parent = source, - } - end - local baseNode = vm.getNode(source._localBase) - if baseNode then - return baseNode - end - baseNode = vm.createNode() - vm.setNode(source._localBase, baseNode, true) - +local function compileLocal(source) vm.setNode(source, source) + local hasMarkDoc if source.bindDocs then hasMarkDoc = bindDocs(source) @@ -788,14 +866,19 @@ local function compileLocalBase(source) if n.type == 'doc.type.function' then for index, arg in ipairs(n.args) do if func.args[index] == source then - vm.setNode(source, vm.compileNode(arg)) + local argNode = vm.compileNode(arg) + for an in argNode:eachObject() do + if an.type ~= 'doc.generic.name' then + vm.setNode(source, an) + end + end hasDocArg = true end end end end if not hasDocArg then - vm.setNode(source, globalMgr.getGlobal('type', 'any')) + vm.setNode(source, vm.declareGlobal('type', 'any')) end end -- for x in ... do @@ -805,15 +888,10 @@ local function compileLocalBase(source) -- for x = ... do if source.parent.type == 'loop' then - vm.setNode(source, globalMgr.getGlobal('type', 'integer')) + vm.compileNode(source.parent) end - baseNode:merge(vm.getNode(source)) - vm.removeNode(source) - - baseNode:setData('hasDefined', hasMarkDoc or hasMarkParam or hasMarkValue) - - return baseNode + vm.getNode(source):setData('hasDefined', hasMarkDoc or hasMarkParam or hasMarkValue) end local compilerSwitch = util.switch() @@ -867,41 +945,79 @@ local compilerSwitch = util.switch() end) : case 'paren' : call(function (source) + if bindAs(source) then + return + end if source.exp then vm.setNode(source, vm.compileNode(source.exp)) end end) : case 'local' : case 'self' + ---@param source parser.object : call(function (source) - local baseNode = compileLocalBase(source) - vm.setNode(source, baseNode, true) - if not baseNode:getData 'hasDefined' and source.ref then + compileLocal(source) + local refs = source.ref + if not refs then + return + end + + local hasMark = vm.getNode(source):getData 'hasDefined' + + local runner = vm.createRunner(source) + runner:launch(function (src, node) + if src.type == 'setlocal' then + if src.bindDocs then + for _, doc in ipairs(src.bindDocs) do + if doc.type == 'doc.type' then + vm.setNode(src, vm.compileNode(doc), true) + return vm.getNode(src) + end + end + end + if src.value and guide.isLiteral(src.value) then + if src.value.type == 'table' then + vm.setNode(src, vm.createNode(src.value), true) + else + vm.setNode(src, vm.compileNode(src.value), true) + end + elseif src.value + and src.value.type == 'binary' + and src.value.op and src.value.op.type == 'or' + and src.value[1] and src.value[1].type == 'getlocal' and src.value[1].node == source then + -- x = x or 1 + vm.setNode(src, vm.compileNode(src.value)) + else + vm.setNode(src, node, true) + end + return vm.getNode(src) + elseif src.type == 'getlocal' then + if bindAs(src) then + return + end + vm.setNode(src, node, true) + end + end) + + if not hasMark then + local parentFunc = guide.getParentFunction(source) for _, ref in ipairs(source.ref) do - if ref.type == 'setlocal' then - vm.setNode(source, vm.compileNode(ref)) + if ref.type == 'setlocal' + and guide.getParentFunction(ref) == parentFunc then + vm.setNode(source, vm.getNode(ref)) end end end end) : case 'setlocal' : call(function (source) - local baseNode = compileLocalBase(source.node) - if not baseNode:getData 'hasDefined' and source.value then - if source.value.type == 'table' then - vm.setNode(source, source.value) - else - vm.setNode(source, vm.compileNode(source.value)) - end - end - baseNode:merge(vm.getNode(source)) - vm.setNode(source, baseNode, true) vm.compileNode(source.node) end) : case 'getlocal' : call(function (source) - local baseNode = compileLocalBase(source.node) - vm.setNode(source, baseNode, true) + if bindAs(source) then + return + end vm.compileNode(source.node) end) : case 'setfield' @@ -924,6 +1040,9 @@ local compilerSwitch = util.switch() : case 'getmethod' : case 'getindex' : call(function (source) + if bindAs(source) then + return + end compileByLocalID(source) local key = guide.getKeyName(source) if key == nil and source.index then @@ -959,6 +1078,9 @@ local compilerSwitch = util.switch() end) : case 'getglobal' : call(function (source) + if bindAs(source) then + return + end if source.node[1] ~= '_ENV' then return end @@ -1019,7 +1141,7 @@ local compilerSwitch = util.switch() end) end if hasGeneric then - vm.setNode(source, genericMgr(rtn, sign)) + vm.setNode(source, vm.createGeneric(rtn, sign)) else vm.setNode(source, vm.compileNode(rtn)) end @@ -1092,29 +1214,44 @@ local compilerSwitch = util.switch() -- for k, v in pairs(t) do --> for k, v in iterator, status, initValue do --> local k, v = iterator(status, initValue) - source._iterator = {} - source._iterArgs = {{}, {}} - -- iterator - selectNode(source._iterator, source.exps, 1) - -- status - selectNode(source._iterArgs[1], source.exps, 2) - -- initValue - selectNode(source._iterArgs[2], source.exps, 3) - end + source._iterator = { + type = 'dummyfunc', + parent = source, + } + source._iterArgs = {{},{}} + end + -- iterator + selectNode(source._iterator, source.exps, 1) + -- status + selectNode(source._iterArgs[1], source.exps, 2) + -- initValue + selectNode(source._iterArgs[2], source.exps, 3) if source.keys then for i, loc in ipairs(source.keys) do local node = getReturn(source._iterator, i, source._iterArgs) if node then + if i == 1 then + node:removeOptional() + end vm.setNode(loc, node) end end end end) + : case 'loop' + : call(function (source) + if source.loc then + vm.setNode(source.loc, vm.declareGlobal('type', 'integer')) + end + end) : case 'doc.type' : call(function (source) for _, typeUnit in ipairs(source.types) do vm.setNode(source, vm.compileNode(typeUnit)) end + if source.optional then + vm.getNode(source):addOptional() + end end) : case 'doc.type.integer' : case 'doc.type.string' @@ -1130,7 +1267,13 @@ local compilerSwitch = util.switch() : call(function (source) local uri = guide.getUri(source) vm.setNode(source, source) - local global = globalMgr.getGlobal('type', source.node[1]) + if not source.node[1] then + return + end + local global = vm.getGlobal('type', source.node[1]) + if not global then + return + end for _, set in ipairs(global:getSets(uri)) do if set.type == 'doc.class' then if set.extends then @@ -1161,14 +1304,22 @@ local compilerSwitch = util.switch() if not source.extends then return end - vm.setNode(source, vm.compileNode(source.extends)) + local fieldNode = vm.compileNode(source.extends) + if source.optional then + fieldNode:addOptional() + end + vm.setNode(source, fieldNode) end) : case 'doc.type.field' : call(function (source) if not source.extends then return end - vm.setNode(source, vm.compileNode(source.extends)) + local fieldNode = vm.compileNode(source.extends) + if source.optional then + fieldNode:addOptional() + end + vm.setNode(source, fieldNode) end) : case 'doc.param' : call(function (source) @@ -1208,7 +1359,7 @@ local compilerSwitch = util.switch() end) : case 'doc.see.name' : call(function (source) - local type = globalMgr.getGlobal('type', source[1]) + local type = vm.getGlobal('type', source[1]) if type then vm.setNode(source, vm.compileNode(type)) end @@ -1218,7 +1369,10 @@ local compilerSwitch = util.switch() if source.extends then vm.setNode(source, vm.compileNode(source.extends)) else - vm.setNode(source, globalMgr.getGlobal('type', 'any')) + vm.setNode(source, vm.declareGlobal('type', 'any')) + end + if source.optional then + vm.getNode(source):addOptional() end end) : case 'generic' @@ -1227,10 +1381,16 @@ local compilerSwitch = util.switch() end) : case 'unary' : call(function (source) + if bindAs(source) then + return + end + if not source[1] then + return + end if source.op.type == 'not' then local result = vm.test(source[1]) if result == nil then - vm.setNode(source, globalMgr.getGlobal('type', 'boolean')) + vm.setNode(source, vm.declareGlobal('type', 'boolean')) return else vm.setNode(source, { @@ -1244,13 +1404,13 @@ local compilerSwitch = util.switch() end end if source.op.type == '#' then - vm.setNode(source, globalMgr.getGlobal('type', 'integer')) + vm.setNode(source, vm.declareGlobal('type', 'integer')) return end if source.op.type == '-' then local v = vm.getNumber(source[1]) if v == nil then - vm.setNode(source, globalMgr.getGlobal('type', 'number')) + vm.setNode(source, vm.declareGlobal('type', 'number')) return else vm.setNode(source, { @@ -1266,7 +1426,7 @@ local compilerSwitch = util.switch() if source.op.type == '~' then local v = vm.getInteger(source[1]) if v == nil then - vm.setNode(source, globalMgr.getGlobal('type', 'integer')) + vm.setNode(source, vm.declareGlobal('type', 'integer')) return else vm.setNode(source, { @@ -1282,34 +1442,42 @@ local compilerSwitch = util.switch() end) : case 'binary' : call(function (source) + if bindAs(source) then + return + end + if not source[1] or not source[2] then + return + end if source.op.type == 'and' then + local node1 = vm.compileNode(source[1]) + local node2 = vm.compileNode(source[2]) local r1 = vm.test(source[1]) if r1 == true then - vm.setNode(source, vm.compileNode(source[2])) - return - end - if r1 == false then - vm.setNode(source, vm.compileNode(source[1])) - return + vm.setNode(source, node2) + elseif r1 == false then + vm.setNode(source, node1) + else + vm.setNode(source, node2) end - return end if source.op.type == 'or' then + local node1 = vm.compileNode(source[1]) + local node2 = vm.compileNode(source[2]) local r1 = vm.test(source[1]) if r1 == true then - vm.setNode(source, vm.compileNode(source[1])) - return - end - if r1 == false then - vm.setNode(source, vm.compileNode(source[2])) - return + vm.setNode(source, node1) + elseif r1 == false then + vm.setNode(source, node2) + else + vm.getNode(source):merge(node1) + vm.getNode(source):setTruthy() + vm.getNode(source):merge(node2) end - return end if source.op.type == '==' then local result = vm.equal(source[1], source[2]) if result == nil then - vm.setNode(source, globalMgr.getGlobal('type', 'boolean')) + vm.setNode(source, vm.declareGlobal('type', 'boolean')) return else vm.setNode(source, { @@ -1325,7 +1493,7 @@ local compilerSwitch = util.switch() if source.op.type == '~=' then local result = vm.equal(source[1], source[2]) if result == nil then - vm.setNode(source, globalMgr.getGlobal('type', 'boolean')) + vm.setNode(source, vm.declareGlobal('type', 'boolean')) return else vm.setNode(source, { @@ -1351,7 +1519,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'integer')) + vm.setNode(source, vm.declareGlobal('type', 'integer')) return end end @@ -1368,7 +1536,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'integer')) + vm.setNode(source, vm.declareGlobal('type', 'integer')) return end end @@ -1385,7 +1553,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'integer')) + vm.setNode(source, vm.declareGlobal('type', 'integer')) return end end @@ -1402,7 +1570,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'integer')) + vm.setNode(source, vm.declareGlobal('type', 'integer')) return end end @@ -1419,7 +1587,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'integer')) + vm.setNode(source, vm.declareGlobal('type', 'integer')) return end end @@ -1437,7 +1605,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'number')) + vm.setNode(source, vm.declareGlobal('type', 'number')) return end end @@ -1455,7 +1623,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'number')) + vm.setNode(source, vm.declareGlobal('type', 'number')) return end end @@ -1473,7 +1641,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'number')) + vm.setNode(source, vm.declareGlobal('type', 'number')) return end end @@ -1490,14 +1658,14 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'number')) + vm.setNode(source, vm.declareGlobal('type', 'number')) return end end if source.op.type == '%' then local a = vm.getNumber(source[1]) local b = vm.getNumber(source[2]) - if a and b then + if a and b and b ~= 0 then local result = a % b vm.setNode(source, { type = math.type(result) == 'integer' and 'integer' or 'number', @@ -1508,7 +1676,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'number')) + vm.setNode(source, vm.declareGlobal('type', 'number')) return end end @@ -1525,7 +1693,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'number')) + vm.setNode(source, vm.declareGlobal('type', 'number')) return end end @@ -1543,7 +1711,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'number')) + vm.setNode(source, vm.declareGlobal('type', 'number')) return end end @@ -1580,7 +1748,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'string')) + vm.setNode(source, vm.declareGlobal('type', 'string')) return end end @@ -1614,17 +1782,20 @@ local function compileByGlobal(source) vm.setNode(source, globalNode, true) return end + ---@type vm.node globalNode = vm.createNode(global) vm.setNode(root._globalBase[name], globalNode, true) + vm.setNode(source, globalNode, true) - local sets = global.links[uri].sets or {} - local gets = global.links[uri].gets or {} - for _, set in ipairs(sets) do - vm.setNode(set, globalNode, true) - end - for _, get in ipairs(gets) do - vm.setNode(get, globalNode, true) - end + -- TODO:don't mix + --local sets = global.links[uri].sets or {} + --local gets = global.links[uri].gets or {} + --for _, set in ipairs(sets) do + -- vm.setNode(set, globalNode, true) + --end + --for _, get in ipairs(gets) do + -- vm.setNode(get, globalNode, true) + --end if global.cate == 'variable' then local hasMarkDoc @@ -1672,7 +1843,11 @@ end ---@return vm.node function vm.compileNode(source) if not source then - error('Can not compile nil node') + if TEST then + error('Can not compile nil source') + else + log.error('Can not compile nil source') + end end if source.type == 'global' then diff --git a/script/vm/def.lua b/script/vm/def.lua index b66e8fda..83e92686 100644 --- a/script/vm/def.lua +++ b/script/vm/def.lua @@ -2,8 +2,6 @@ local vm = require 'vm.vm' local util = require 'utility' local guide = require 'parser.guide' -local localID = require 'vm.local-id' -local globalMgr = require 'vm.global-manager' local simpleSwitch @@ -79,6 +77,13 @@ simpleSwitch = util.switch() pushResult(source.node) end end) + : case 'doc.cast.name' + : call(function (source, pushResult) + local loc = guide.getLocal(source, source[1], source.start) + if loc then + pushResult(loc) + end + end) local searchFieldSwitch = util.switch() : case 'table' @@ -97,7 +102,7 @@ local searchFieldSwitch = util.switch() ---@param key string : call(function (suri, obj, key, pushResult) if obj.cate == 'variable' then - local newGlobal = globalMgr.getGlobal('variable', obj.name, key) + local newGlobal = vm.getGlobal('variable', obj.name, key) if newGlobal then for _, set in ipairs(newGlobal:getSets(suri)) do pushResult(set) @@ -110,7 +115,7 @@ local searchFieldSwitch = util.switch() end) : case 'local' : call(function (suri, obj, key, pushResult) - local sources = localID.getSources(obj, key) + local sources = vm.getLocalSources(obj, key) if sources then for _, src in ipairs(sources) do if guide.isSet(src) then @@ -189,7 +194,7 @@ end ---@param source parser.object ---@param pushResult fun(src: parser.object) local function searchByLocalID(source, pushResult) - local idSources = localID.getSources(source) + local idSources = vm.getLocalSources(source) if not idSources then return end diff --git a/script/vm/doc.lua b/script/vm/doc.lua index 5a92a103..e2b383b6 100644 --- a/script/vm/doc.lua +++ b/script/vm/doc.lua @@ -3,7 +3,6 @@ local guide = require 'parser.guide' ---@class vm local vm = require 'vm.vm' local config = require 'config' -local globalMgr = require 'vm.global-manager' ---获取class与alias ---@param suri uri @@ -11,13 +10,13 @@ local globalMgr = require 'vm.global-manager' ---@return parser.object[] function vm.getDocSets(suri, name) if name then - local global = globalMgr.getGlobal('type', name) + local global = vm.getGlobal('type', name) if not global then return {} end return global:getSets(suri) else - return globalMgr.getGlobalSets(suri, 'type') + return vm.getGlobalSets(suri, 'type') end end @@ -27,6 +26,9 @@ function vm.isMetaFile(uri) return false end local cache = files.getCache(uri) + if not cache then + return false + end if cache.isMeta ~= nil then return cache.isMeta end @@ -332,6 +334,9 @@ function vm.isDiagDisabledAt(uri, position, name) return false end local cache = files.getCache(uri) + if not cache then + return false + end if not cache.diagnosticRanges then cache.diagnosticRanges = {} for _, doc in ipairs(status.ast.docs) do diff --git a/script/vm/field.lua b/script/vm/field.lua index ba7cd4c1..5de838be 100644 --- a/script/vm/field.lua +++ b/script/vm/field.lua @@ -15,6 +15,15 @@ local searchByNodeSwitch = util.switch() pushResult(source) end) +local function searchByLocalID(source, pushResult) + local fields = vm.getLocalFields(source) + if fields then + for _, field in ipairs(fields) do + pushResult(field) + end + end +end + local function searchByNode(source, pushResult) local uri = guide.getUri(source) vm.compileByParentNode(source, nil, true, function (field) @@ -35,6 +44,7 @@ function vm.getFields(source) end end + searchByLocalID(source, pushResult) searchByNode(source, pushResult) return results diff --git a/script/vm/generic.lua b/script/vm/generic.lua index b3981ff8..6462028e 100644 --- a/script/vm/generic.lua +++ b/script/vm/generic.lua @@ -1,3 +1,4 @@ +---@class vm local vm = require 'vm.vm' ---@class parser.object @@ -114,7 +115,7 @@ end ---@param uri uri ---@param args parser.object ----@return parser.object +---@return vm.node function mt:resolve(uri, args) local resolved = self.sign:resolve(uri, args) local protoNode = vm.compileNode(self.proto) @@ -129,7 +130,7 @@ end ---@param proto vm.object ---@param sign vm.sign ---@return vm.generic -return function (proto, sign) +function vm.createGeneric(proto, sign) local generic = setmetatable({ sign = sign, proto = proto, diff --git a/script/vm/global-manager.lua b/script/vm/global-manager.lua deleted file mode 100644 index f25bb5a0..00000000 --- a/script/vm/global-manager.lua +++ /dev/null @@ -1,364 +0,0 @@ -local util = require 'utility' -local guide = require 'parser.guide' -local globalBuilder = require 'vm.global' -local signMgr = require 'vm.sign' -local genericMgr = require 'vm.generic' ----@class vm -local vm = require 'vm.vm' - ----@class parser.object ----@field _globalNode vm.global - ----@class vm.global-manager -local m = {} ----@type table<string, vm.global> -m.globals = {} ----@type table<uri, table<string, boolean>> -m.globalSubs = util.multiTable(2) - -local compilerGlobalSwitch = util.switch() - : case 'local' - : call(function (source) - if source.special ~= '_G' then - return - end - if source.ref then - for _, ref in ipairs(source.ref) do - m.compileObject(ref) - end - end - end) - : case 'getlocal' - : call(function (source) - if source.special ~= '_G' then - return - end - if not source.next then - return - end - m.compileObject(source.next) - end) - : case 'setglobal' - : call(function (source) - local uri = guide.getUri(source) - local name = guide.getKeyName(source) - local global = m.declareGlobal('variable', name, uri) - global:addSet(uri, source) - source._globalNode = global - end) - : case 'getglobal' - : call(function (source) - local uri = guide.getUri(source) - local name = guide.getKeyName(source) - local global = m.declareGlobal('variable', name, uri) - global:addGet(uri, source) - source._globalNode = global - - local nxt = source.next - if nxt then - m.compileObject(nxt) - end - end) - : case 'setfield' - : case 'setmethod' - : case 'setindex' - ---@param source parser.object - : call(function (source) - local name - local keyName = guide.getKeyName(source) - if not keyName then - return - end - if source.node._globalNode then - local parentName = source.node._globalNode:getName() - if parentName == '_G' then - name = keyName - else - name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName) - end - elseif source.node.special == '_G' then - name = keyName - end - if not name then - return - end - local uri = guide.getUri(source) - local global = m.declareGlobal('variable', name, uri) - global:addSet(uri, source) - source._globalNode = global - end) - : case 'getfield' - : case 'getmethod' - : case 'getindex' - ---@param source parser.object - : call(function (source) - local name - local keyName = guide.getKeyName(source) - if not keyName then - return - end - if source.node._globalNode then - local parentName = source.node._globalNode:getName() - if parentName == '_G' then - name = keyName - else - name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName) - end - elseif source.node.special == '_G' then - name = keyName - end - local uri = guide.getUri(source) - local global = m.declareGlobal('variable', name, uri) - global:addGet(uri, source) - source._globalNode = global - - local nxt = source.next - if nxt then - m.compileObject(nxt) - end - end) - : case 'call' - : call(function (source) - if source.node.special == 'rawset' - or source.node.special == 'rawget' then - if not source.args then - return - end - local g = source.args[1] - local key = source.args[2] - if g and key and g.special == '_G' then - local name = guide.getKeyName(key) - if name then - local uri = guide.getUri(source) - local global = m.declareGlobal('variable', name, uri) - if source.node.special == 'rawset' then - global:addSet(uri, source) - source.value = source.args[3] - else - global:addGet(uri, source) - end - source._globalNode = global - - local nxt = source.next - if nxt then - m.compileObject(nxt) - end - end - end - end - end) - : case 'doc.class' - ---@param source parser.object - : call(function (source) - local uri = guide.getUri(source) - local name = guide.getKeyName(source) - local class = m.declareGlobal('type', name, uri) - class:addSet(uri, source) - source._globalNode = class - - if source.signs then - source._sign = signMgr() - for _, sign in ipairs(source.signs) do - source._sign:addSign(vm.compileNode(sign)) - end - if source.extends then - for _, ext in ipairs(source.extends) do - if ext.type == 'doc.type.table' then - ext._generic = genericMgr(ext, source._sign) - end - end - end - end - end) - : case 'doc.alias' - : call(function (source) - local uri = guide.getUri(source) - local name = guide.getKeyName(source) - local alias = m.declareGlobal('type', name, uri) - alias:addSet(uri, source) - source._globalNode = alias - - if source.signs then - source._sign = signMgr() - for _, sign in ipairs(source.signs) do - source._sign:addSign(vm.compileNode(sign)) - end - source.extends._generic = genericMgr(source.extends, source._sign) - end - end) - : case 'doc.type.name' - : call(function (source) - local uri = guide.getUri(source) - local name = source[1] - local type = m.declareGlobal('type', name, uri) - type:addGet(uri, source) - source._globalNode = type - end) - : case 'doc.extends.name' - : call(function (source) - local uri = guide.getUri(source) - local name = source[1] - local class = m.declareGlobal('type', name, uri) - class:addGet(uri, source) - source._globalNode = class - end) - - ----@alias vm.global.cate '"variable"' | '"type"' - ----@param cate vm.global.cate ----@param name string ----@param uri uri ----@return vm.global -function m.declareGlobal(cate, name, uri) - local key = cate .. '|' .. name - m.globalSubs[uri][key] = true - if not m.globals[key] then - m.globals[key] = globalBuilder(name, cate) - end - return m.globals[key] -end - ----@param cate vm.global.cate ----@param name string ----@param field? string ----@return vm.global? -function m.getGlobal(cate, name, field) - local key = cate .. '|' .. name - if field then - key = key .. vm.ID_SPLITE .. field - end - return m.globals[key] -end - ----@param cate vm.global.cate ----@param name string ----@return vm.global[] -function m.getFields(cate, name) - local globals = {} - local key = cate .. '|' .. name - - -- TODO: optimize - local clock = os.clock() - for gid, global in pairs(m.globals) do - if gid ~= key - and util.stringStartWith(gid, key) - and gid:sub(#key + 1, #key + 1) == vm.ID_SPLITE - and not gid:find(vm.ID_SPLITE, #key + 2) then - globals[#globals+1] = global - end - end - local cost = os.clock() - clock - if cost > 0.1 then - log.warn('global-manager getFields cost %.3f', cost) - end - - return globals -end - ----@param cate vm.global.cate ----@return vm.global[] -function m.getGlobals(cate) - local globals = {} - - -- TODO: optimize - local clock = os.clock() - for gid, global in pairs(m.globals) do - if util.stringStartWith(gid, cate) - and not gid:find(vm.ID_SPLITE) then - globals[#globals+1] = global - end - end - local cost = os.clock() - clock - if cost > 0.1 then - log.warn('global-manager getGlobals cost %.3f', cost) - end - - return globals -end - ----@param suri uri ----@param cate vm.global.cate ----@return parser.object[] -function m.getGlobalSets(suri, cate) - local globals = m.getGlobals(cate) - local result = {} - for _, global in ipairs(globals) do - local sets = global:getSets(suri) - for _, set in ipairs(sets) do - result[#result+1] = set - end - end - return result -end - ----@param suri uri ----@param cate vm.global.cate ----@param name string ----@return boolean -function m.hasGlobalSets(suri, cate, name) - local global = m.getGlobal(cate, name) - if not global then - return false - end - local sets = global:getSets(suri) - if #sets == 0 then - return false - end - return true -end - ----@param source parser.object -function m.compileObject(source) - if source._globalNode ~= nil then - return - end - source._globalNode = false - compilerGlobalSwitch(source.type, source) -end - ----@param source parser.object -function m.compileAst(source) - local env = guide.getENV(source) - m.compileObject(env) - guide.eachSpecialOf(source, 'rawset', function (src) - m.compileObject(src.parent) - end) - guide.eachSpecialOf(source, 'rawget', function (src) - m.compileObject(src.parent) - end) - guide.eachSourceTypes(source.docs, { - 'doc.class', - 'doc.alias', - 'doc.type.name', - 'doc.extends.name', - }, function (src) - m.compileObject(src) - end) -end - ----@return vm.global -function m.getNode(source) - if source.type == 'field' - or source.type == 'method' then - source = source.parent - end - return source._globalNode -end - ----@param uri uri -function m.dropUri(uri) - local globalSub = m.globalSubs[uri] - m.globalSubs[uri] = nil - for key in pairs(globalSub) do - local global = m.globals[key] - if global then - global:dropUri(uri) - if not global:isAlive() then - m.globals[key] = nil - end - end - end -end - -return m diff --git a/script/vm/global.lua b/script/vm/global.lua index 1c46c9a3..a54ab552 100644 --- a/script/vm/global.lua +++ b/script/vm/global.lua @@ -1,5 +1,9 @@ -local util = require 'utility' -local scope= require 'workspace.scope' +local util = require 'utility' +local scope = require 'workspace.scope' +local guide = require 'parser.guide' +local files = require 'files' +---@class vm +local vm = require 'vm.vm' ---@class vm.global.link ---@field gets parser.object[] @@ -15,8 +19,6 @@ mt.__index = mt mt.type = 'global' mt.name = '' -local ID_SPLITE = '\x1F' - ---@param uri uri ---@param source parser.object function mt:addSet(uri, source) @@ -106,7 +108,7 @@ end ---@return string function mt:getKeyName() - return self.name:match('[^' .. ID_SPLITE .. ']+$') + return self.name:match('[^' .. vm.ID_SPLITE .. ']+$') end ---@return boolean @@ -116,10 +118,427 @@ end ---@param cate vm.global.cate ---@return vm.global -return function (name, cate) +local function createGlobal(name, cate) return setmetatable({ name = name, cate = cate, links = util.multiTable(2), }, mt) end + +---@class parser.object +---@field _globalNode vm.global + +---@type table<string, vm.global> +local allGlobals = {} +---@type table<uri, table<string, boolean>> +local globalSubs = util.multiTable(2) + +local compileObject +local compilerGlobalSwitch = util.switch() + : case 'local' + : call(function (source) + if source.special ~= '_G' then + return + end + if source.ref then + for _, ref in ipairs(source.ref) do + compileObject(ref) + end + end + end) + : case 'getlocal' + : call(function (source) + if source.special ~= '_G' then + return + end + if not source.next then + return + end + compileObject(source.next) + end) + : case 'setglobal' + : call(function (source) + local uri = guide.getUri(source) + local name = guide.getKeyName(source) + local global = vm.declareGlobal('variable', name, uri) + global:addSet(uri, source) + source._globalNode = global + end) + : case 'getglobal' + : call(function (source) + local uri = guide.getUri(source) + local name = guide.getKeyName(source) + local global = vm.declareGlobal('variable', name, uri) + global:addGet(uri, source) + source._globalNode = global + + local nxt = source.next + if nxt then + compileObject(nxt) + end + end) + : case 'setfield' + : case 'setmethod' + : case 'setindex' + ---@param source parser.object + : call(function (source) + local name + local keyName = guide.getKeyName(source) + if not keyName then + return + end + if source.node._globalNode then + local parentName = source.node._globalNode:getName() + if parentName == '_G' then + name = keyName + else + name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName) + end + elseif source.node.special == '_G' then + name = keyName + end + if not name then + return + end + local uri = guide.getUri(source) + local global = vm.declareGlobal('variable', name, uri) + global:addSet(uri, source) + source._globalNode = global + end) + : case 'getfield' + : case 'getmethod' + : case 'getindex' + ---@param source parser.object + : call(function (source) + local name + local keyName = guide.getKeyName(source) + if not keyName then + return + end + if source.node._globalNode then + local parentName = source.node._globalNode:getName() + if parentName == '_G' then + name = keyName + else + name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName) + end + elseif source.node.special == '_G' then + name = keyName + end + local uri = guide.getUri(source) + local global = vm.declareGlobal('variable', name, uri) + global:addGet(uri, source) + source._globalNode = global + + local nxt = source.next + if nxt then + compileObject(nxt) + end + end) + : case 'call' + : call(function (source) + if source.node.special == 'rawset' + or source.node.special == 'rawget' then + if not source.args then + return + end + local g = source.args[1] + local key = source.args[2] + if g and key and g.special == '_G' then + local name = guide.getKeyName(key) + if name then + local uri = guide.getUri(source) + local global = vm.declareGlobal('variable', name, uri) + if source.node.special == 'rawset' then + global:addSet(uri, source) + source.value = source.args[3] + else + global:addGet(uri, source) + end + source._globalNode = global + + local nxt = source.next + if nxt then + compileObject(nxt) + end + end + end + end + end) + : case 'doc.class' + ---@param source parser.object + : call(function (source) + local uri = guide.getUri(source) + local name = guide.getKeyName(source) + local class = vm.declareGlobal('type', name, uri) + class:addSet(uri, source) + source._globalNode = class + + if source.signs then + source._sign = vm.createSign() + for _, sign in ipairs(source.signs) do + source._sign:addSign(vm.compileNode(sign)) + end + if source.extends then + for _, ext in ipairs(source.extends) do + if ext.type == 'doc.type.table' then + ext._generic = vm.createGeneric(ext, source._sign) + end + end + end + end + end) + : case 'doc.alias' + : call(function (source) + local uri = guide.getUri(source) + local name = guide.getKeyName(source) + local alias = vm.declareGlobal('type', name, uri) + alias:addSet(uri, source) + source._globalNode = alias + + if source.signs then + source._sign = vm.createSign() + for _, sign in ipairs(source.signs) do + source._sign:addSign(vm.compileNode(sign)) + end + source.extends._generic = vm.createGeneric(source.extends, source._sign) + end + end) + : case 'doc.type.name' + : call(function (source) + local uri = guide.getUri(source) + local name = source[1] + local type = vm.declareGlobal('type', name, uri) + type:addGet(uri, source) + source._globalNode = type + end) + : case 'doc.extends.name' + : call(function (source) + local uri = guide.getUri(source) + local name = source[1] + local class = vm.declareGlobal('type', name, uri) + class:addGet(uri, source) + source._globalNode = class + end) + + +---@alias vm.global.cate '"variable"' | '"type"' + +---@param cate vm.global.cate +---@param name string +---@param uri? uri +---@return vm.global +function vm.declareGlobal(cate, name, uri) + local key = cate .. '|' .. name + if uri then + globalSubs[uri][key] = true + end + if not allGlobals[key] then + allGlobals[key] = createGlobal(name, cate) + end + return allGlobals[key] +end + +---@param cate vm.global.cate +---@param name string +---@param field? string +---@return vm.global? +function vm.getGlobal(cate, name, field) + local key = cate .. '|' .. name + if field then + key = key .. vm.ID_SPLITE .. field + end + return allGlobals[key] +end + +---@param cate vm.global.cate +---@param name string +---@return vm.global[] +function vm.getGlobalFields(cate, name) + local globals = {} + local key = cate .. '|' .. name + + local clock = os.clock() + for gid, global in pairs(allGlobals) do + if gid ~= key + and util.stringStartWith(gid, key) + and gid:sub(#key + 1, #key + 1) == vm.ID_SPLITE + and not gid:find(vm.ID_SPLITE, #key + 2) then + globals[#globals+1] = global + end + end + local cost = os.clock() - clock + if cost > 0.1 then + log.warn('global-manager getFields cost %.3f', cost) + end + + return globals +end + +---@param cate vm.global.cate +---@return vm.global[] +function vm.getGlobals(cate) + local globals = {} + + local clock = os.clock() + for gid, global in pairs(allGlobals) do + if util.stringStartWith(gid, cate) + and not gid:find(vm.ID_SPLITE) then + globals[#globals+1] = global + end + end + local cost = os.clock() - clock + if cost > 0.1 then + log.warn('global-manager getGlobals cost %.3f', cost) + end + + return globals +end + +---@param suri uri +---@param cate vm.global.cate +---@return parser.object[] +function vm.getGlobalSets(suri, cate) + local globals = vm.getGlobals(cate) + local result = {} + for _, global in ipairs(globals) do + local sets = global:getSets(suri) + for _, set in ipairs(sets) do + result[#result+1] = set + end + end + return result +end + +---@param suri uri +---@param cate vm.global.cate +---@param name string +---@return boolean +function vm.hasGlobalSets(suri, cate, name) + local global = vm.getGlobal(cate, name) + if not global then + return false + end + local sets = global:getSets(suri) + if #sets == 0 then + return false + end + return true +end + +---@param source parser.object +function compileObject(source) + if source._globalNode ~= nil then + return + end + source._globalNode = false + compilerGlobalSwitch(source.type, source) +end + +---@param source parser.object +local function compileSelf(source) + if source.parent.type ~= 'funcargs' then + return + end + ---@type parser.object + local node = source.parent.parent and source.parent.parent.parent and source.parent.parent.parent.node + if not node then + return + end + local fields = vm.getLocalFields(source) + if not fields then + return + end + local nodeLocalID = vm.getLocalID(node) + local globalNode = node._globalNode + if not nodeLocalID and not globalNode then + return + end + for _, field in ipairs(fields) do + if field.type == 'setfield' then + local key = guide.getKeyName(field) + if key then + if nodeLocalID then + local myID = nodeLocalID .. vm.ID_SPLITE .. key + vm.insertLocalID(myID, field) + end + if globalNode then + local myID = globalNode:getName() .. vm.ID_SPLITE .. key + local myGlobal = vm.declareGlobal('variable', myID, guide.getUri(node)) + myGlobal:addSet(guide.getUri(node), field) + end + end + end + end +end + +---@param source parser.object +local function compileAst(source) + local env = guide.getENV(source) + if not env then + return + end + compileObject(env) + guide.eachSpecialOf(source, 'rawset', function (src) + compileObject(src.parent) + end) + guide.eachSpecialOf(source, 'rawget', function (src) + compileObject(src.parent) + end) + guide.eachSourceTypes(source.docs, { + 'doc.class', + 'doc.alias', + 'doc.type.name', + 'doc.extends.name', + }, function (src) + compileObject(src) + end) + + --[[ + local mt + function mt:xxx() + self.a = 1 + end + + mt.a --> find this definition + ]] + guide.eachSourceType(source, 'self', function (src) + compileSelf(src) + end) +end + +---@param uri uri +local function dropUri(uri) + local globalSub = globalSubs[uri] + globalSubs[uri] = nil + for key in pairs(globalSub) do + local global = allGlobals[key] + if global then + global:dropUri(uri) + if not global:isAlive() then + allGlobals[key] = nil + end + end + end +end + +for uri in files.eachFile() do + local state = files.getState(uri) + if state then + compileAst(state.ast) + end +end + +files.watch(function (ev, uri) + if ev == 'update' then + dropUri(uri) + local state = files.getState(uri) + if state then + compileAst(state.ast) + end + end + if ev == 'remove' then + dropUri(uri) + end +end) diff --git a/script/vm/infer.lua b/script/vm/infer.lua index 2a64ed52..fabc9828 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -1,11 +1,9 @@ local util = require 'utility' local config = require 'config' local guide = require 'parser.guide' +---@class vm local vm = require 'vm.vm' ----@class vm.infer-manager -local m = {} - ---@class vm.infer ---@field views table<string, boolean> ---@field cachedView? string @@ -21,7 +19,7 @@ mt._hasDocFunction = false mt._isParam = false mt._isLocal = false -m.NULL = setmetatable({}, mt) +vm.NULL = setmetatable({}, mt) local inferSorted = { ['boolean'] = - 100, @@ -52,7 +50,7 @@ local viewNodeSwitch = util.switch() : call(function (source, infer) if source.type == 'table' then if #source == 1 and source[1].type == 'varargs' then - local node = m.getInfer(source[1]):view() + local node = vm.getInfer(source[1]):view() return ('%s[]'):format(node) end end @@ -90,7 +88,7 @@ local viewNodeSwitch = util.switch() if source.signs then local buf = {} for i, sign in ipairs(source.signs) do - buf[i] = m.getInfer(sign):view() + buf[i] = vm.getInfer(sign):view() end return ('%s<%s>'):format(source[1], table.concat(buf, ', ')) else @@ -99,7 +97,7 @@ local viewNodeSwitch = util.switch() end) : case 'generic' : call(function (source, infer) - return m.getInfer(source.proto):view() + return vm.getInfer(source.proto):view() end) : case 'doc.generic.name' : call(function (source, infer) @@ -108,7 +106,7 @@ local viewNodeSwitch = util.switch() : case 'doc.type.array' : call(function (source, infer) infer._hasClass = true - local view = m.getInfer(source.node):view() + local view = vm.getInfer(source.node):view() if source.node.type == 'doc.type' then view = '(' .. view .. ')' end @@ -119,7 +117,7 @@ local viewNodeSwitch = util.switch() infer._hasClass = true local buf = {} for i, sign in ipairs(source.signs) do - buf[i] = m.getInfer(sign):view() + buf[i] = vm.getInfer(sign):view() end return ('%s<%s>'):format(source.node[1], table.concat(buf, ', ')) end) @@ -144,20 +142,23 @@ local viewNodeSwitch = util.switch() local argView = '' local regView = '' for i, arg in ipairs(source.args) do + local argNode = vm.compileNode(arg) + local isOptional = argNode:isOptional() + if isOptional then + argNode = argNode:copy() + argNode:removeOptional() + end args[i] = string.format('%s%s: %s' , arg.name[1] - , arg.optional and '?' or '' - , m.getInfer(arg):view() + , isOptional and '?' or '' + , vm.getInfer(argNode):view() ) end if #args > 0 then argView = table.concat(args, ', ') end for i, ret in ipairs(source.returns) do - rets[i] = string.format('%s%s' - , m.getInfer(ret):view() - , ret.optional and '?' or '' - ) + rets[i] = vm.getInfer(ret):view() end if #rets > 0 then regView = ':' .. table.concat(rets, ', ') @@ -165,16 +166,21 @@ local viewNodeSwitch = util.switch() return ('fun(%s)%s'):format(argView, regView) end) ----@param source parser.object +---@param source parser.object | vm.node ---@return vm.infer -function m.getInfer(source) - local node = vm.compileNode(source) +function vm.getInfer(source) + local node + if source.type == 'vm.node' then + node = source + else + node = vm.compileNode(source) + end if node.lastInfer then return node.lastInfer end local infer = setmetatable({ node = node, - uri = guide.getUri(source), + uri = source.type ~= 'vm.node' and guide.getUri(source), }, mt) node.lastInfer = infer @@ -199,24 +205,24 @@ function mt:_trim() if self._hasTable and not self._hasClass then self.views['table'] = true end - if self._hasClass then - self:_eraseAlias() - end end -function mt:_eraseAlias() - local expandAlias = config.get(self.uri, 'Lua.hover.expandAlias') +---@param uri uri +---@return table<string, true> +function mt:_eraseAlias(uri) + local drop = {} + local expandAlias = config.get(uri, 'Lua.hover.expandAlias') for n in self.node:eachObject() do if n.type == 'global' and n.cate == 'type' then - for _, set in ipairs(n:getSets(self.uri)) do + for _, set in ipairs(n:getSets(uri)) do if set.type == 'doc.alias' then if expandAlias then - self.views[n.name] = nil + drop[n.name] = true else for _, ext in ipairs(set.extends.types) do local view = viewNodeSwitch(ext.type, ext, {}) if view and view ~= n.name then - self.views[view] = nil + drop[view] = true end end end @@ -224,6 +230,7 @@ function mt:_eraseAlias() end end end + return drop end ---@param tp string @@ -273,17 +280,16 @@ function mt:view(default, uri) return 'any' end - if not next(self.views) then - return default or 'unknown' - end - - if self.cachedView then - return self.cachedView + local drop + if self._hasClass then + drop = self:_eraseAlias(uri or self.uri) end local array = {} for view in pairs(self.views) do - array[#array+1] = view + if not drop or not drop[view] then + array[#array+1] = view + end end table.sort(array, function (a, b) @@ -298,22 +304,29 @@ function mt:view(default, uri) local max = #array local limit = config.get(uri or self.uri, 'Lua.hover.enumsLimit') - if max > limit then - local view = string.format('%s...(+%d)' - , table.concat(array, '|', 1, limit) - , max - limit - ) - - self.cachedView = view - - return view + local view + if #array == 0 then + view = default or 'unknown' else - local view = table.concat(array, '|') - - self.cachedView = view + if max > limit then + view = string.format('%s...(+%d)' + , table.concat(array, '|', 1, limit) + , max - limit + ) + else + view = table.concat(array, '|') + end + end - return view + if self.node:isOptional() then + if max > 1 then + view = '(' .. view .. ')?' + else + view = view .. '?' + end end + + return view end function mt:eachView() @@ -324,10 +337,10 @@ end ---@param other vm.infer ---@return vm.infer function mt:merge(other) - if self == m.NULL then + if self == vm.NULL then return other end - if other == m.NULL then + if other == vm.NULL then return self end @@ -390,8 +403,6 @@ end ---@param source parser.object ---@return string? -function m.viewObject(source) +function vm.viewObject(source) return viewNodeSwitch(source.type, source, {}) end - -return m diff --git a/script/vm/init.lua b/script/vm/init.lua index 0058c698..f5003c11 100644 --- a/script/vm/init.lua +++ b/script/vm/init.lua @@ -1,4 +1,7 @@ local vm = require 'vm.vm' + +---@alias vm.object parser.object | vm.global | vm.generic + require 'vm.compiler' require 'vm.value' require 'vm.node' @@ -8,5 +11,10 @@ require 'vm.field' require 'vm.doc' require 'vm.type' require 'vm.library' -require 'vm.manager' +require 'vm.runner' +require 'vm.infer' +require 'vm.generic' +require 'vm.sign' +require 'vm.local-id' +require 'vm.global' return vm diff --git a/script/vm/library.lua b/script/vm/library.lua index 49f7adb0..e7bf4f42 100644 --- a/script/vm/library.lua +++ b/script/vm/library.lua @@ -13,24 +13,3 @@ function vm.getLibraryName(source) end return nil end - -local globalLibraryNames = { - 'arg', 'assert', 'error', 'collectgarbage', 'dofile', '_G', 'getfenv', - 'getmetatable', 'ipairs', 'load', 'loadfile', 'loadstring', - 'module', 'next', 'pairs', 'pcall', 'print', 'rawequal', - 'rawget', 'rawlen', 'rawset', 'select', 'setfenv', - 'setmetatable', 'tonumber', 'tostring', 'type', '_VERSION', - 'warn', 'xpcall', 'require', 'unpack', 'bit32', 'coroutine', - 'debug', 'io', 'math', 'os', 'package', 'string', 'table', - 'utf8', 'newproxy', -} -local globalLibraryNamesMap -function vm.isGlobalLibraryName(name) - if not globalLibraryNamesMap then - globalLibraryNamesMap = {} - for _, v in ipairs(globalLibraryNames) do - globalLibraryNamesMap[v] = true - end - end - return globalLibraryNamesMap[name] or false -end diff --git a/script/vm/local-id.lua b/script/vm/local-id.lua index 728de301..80c68769 100644 --- a/script/vm/local-id.lua +++ b/script/vm/local-id.lua @@ -1,13 +1,13 @@ local util = require 'utility' local guide = require 'parser.guide' +---@class vm local vm = require 'vm.vm' ---@class parser.object ---@field _localID string ---@field _localIDs table<string, parser.object[]> ----@class vm.local-id -local m = {} +local compileLocalID, getLocal local compileSwitch = util.switch() : case 'local' @@ -18,13 +18,13 @@ local compileSwitch = util.switch() return end for _, ref in ipairs(source.ref) do - m.compileLocalID(ref) + compileLocalID(ref) end end) : case 'getlocal' : call(function (source) source._localID = ('%d'):format(source.node.start) - m.compileLocalID(source.next) + compileLocalID(source.next) end) : case 'getfield' : case 'setfield' @@ -40,7 +40,7 @@ local compileSwitch = util.switch() source._localID = parentID .. vm.ID_SPLITE .. key source.field._localID = source._localID if source.type == 'getfield' then - m.compileLocalID(source.next) + compileLocalID(source.next) end end) : case 'getmethod' @@ -57,7 +57,7 @@ local compileSwitch = util.switch() source._localID = parentID .. vm.ID_SPLITE .. key source.method._localID = source._localID if source.type == 'getmethod' then - m.compileLocalID(source.next) + compileLocalID(source.next) end end) : case 'getindex' @@ -74,7 +74,7 @@ local compileSwitch = util.switch() source._localID = parentID .. vm.ID_SPLITE .. key source.index._localID = source._localID if source.type == 'setindex' then - m.compileLocalID(source.next) + compileLocalID(source.next) end end) @@ -82,7 +82,7 @@ local leftSwitch = util.switch() : case 'field' : case 'method' : call(function (source) - return m.getLocal(source.parent) + return getLocal(source.parent) end) : case 'getfield' : case 'setfield' @@ -91,24 +91,36 @@ local leftSwitch = util.switch() : case 'getindex' : case 'setindex' : call(function (source) - return m.getLocal(source.node) + return getLocal(source.node) end) : case 'getlocal' : call(function (source) return source.node end) : case 'local' + : case 'self' : call(function (source) return source end) ---@param source parser.object ---@return parser.object? -function m.getLocal(source) +function getLocal(source) return leftSwitch(source.type, source) end -function m.compileLocalID(source) +---@param id string +---@param source parser.object +function vm.insertLocalID(id, source) + local root = guide.getRoot(source) + if not root._localIDs then + root._localIDs = util.multiTable(2) + end + local sources = root._localIDs[id] + sources[#sources+1] = source +end + +function compileLocalID(source) if not source then return end @@ -117,37 +129,33 @@ function m.compileLocalID(source) return end compileSwitch(source.type, source) - if not source._localID then + local id = source._localID + if not id then return end - local root = guide.getRoot(source) - if not root._localIDs then - root._localIDs = util.multiTable(2) - end - local sources = root._localIDs[source._localID] - sources[#sources+1] = source + vm.insertLocalID(id, source) end ---@param source parser.object ----@return string|boolean -function m.getID(source) +---@return string? +function vm.getLocalID(source) if source._localID ~= nil then return source._localID end source._localID = false - local loc = m.getLocal(source) + local loc = getLocal(source) if not loc then return source._localID end - m.compileLocalID(loc) + compileLocalID(loc) return source._localID end ---@param source parser.object ---@param key? string ---@return parser.object[]? -function m.getSources(source, key) - local id = m.getID(source) +function vm.getLocalSources(source, key) + local id = vm.getLocalID(source) if not id then return nil end @@ -166,8 +174,8 @@ end ---@param source parser.object ---@return parser.object[] -function m.getFields(source) - local id = m.getID(source) +function vm.getLocalFields(source) + local id = vm.getLocalID(source) if not id then return nil end @@ -195,5 +203,3 @@ function m.getFields(source) end return fields end - -return m diff --git a/script/vm/local-manager.lua b/script/vm/local-manager.lua deleted file mode 100644 index 51bafb24..00000000 --- a/script/vm/local-manager.lua +++ /dev/null @@ -1,40 +0,0 @@ -local util = require 'utility' -local guide = require 'parser.guide' - ----@class vm.local-node -local m = {} ----@type table<uri, parser.object[]> -m.locals = util.multiTable(2) ----@type table<parser.object, table<parser.object, boolean>> -m.localSubs = util.multiTable(2, function () - return setmetatable({}, util.MODE_K) -end) ----@type table<parser.object, boolean> -m.allLocals = {} - ----@param source parser.object -function m.declareLocal(source) - if m.allLocals[source] then - return - end - m.allLocals[source] = true - local uri = guide.getUri(source) - local locals = m.locals[uri] - locals[#locals+1] = source -end - ----@param uri uri -function m.dropUri(uri) - local locals = m.locals[uri] - m.locals[uri] = nil - for _, loc in ipairs(locals) do - m.allLocals[loc] = nil - local localSubs = m.localSubs[loc] - m.localSubs[loc] = nil - for source in pairs(localSubs) do - source._node = nil - end - end -end - -return m diff --git a/script/vm/manager.lua b/script/vm/manager.lua deleted file mode 100644 index 58255fca..00000000 --- a/script/vm/manager.lua +++ /dev/null @@ -1,26 +0,0 @@ - -local files = require 'files' -local globalManager = require 'vm.global-manager' -local localManager = require 'vm.local-manager' - ----@alias vm.object parser.object | vm.global | vm.generic - ----@class vm.state -local m = {} - -files.watch(function (ev, uri) - if ev == 'update' then - globalManager.dropUri(uri) - localManager.dropUri(uri) - local state = files.getState(uri) - if state then - globalManager.compileAst(state.ast) - end - end - if ev == 'remove' then - globalManager.dropUri(uri) - localManager.dropUri(uri) - end -end) - -return m diff --git a/script/vm/node.lua b/script/vm/node.lua index 6906da7e..e76542aa 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -1,5 +1,4 @@ local files = require 'files' -local localMgr = require 'vm.local-manager' ---@class vm local vm = require 'vm.vm' local ws = require 'workspace.workspace' @@ -8,15 +7,14 @@ local ws = require 'workspace.workspace' vm.nodeCache = {} ---@class vm.node +---@field [integer] vm.object local mt = {} mt.__index = mt +mt.id = 0 mt.type = 'vm.node' mt.optional = nil mt.lastInfer = nil mt.data = nil ----@type vm.node[] -mt._childs = nil -mt._locked = false ---@param node vm.node | vm.object function mt:merge(node) @@ -30,20 +28,10 @@ function mt:merge(node) if node:isOptional() then self.optional = true end - if node._locked then - if not self._childs then - self._childs = {} - end - if not self._childs[node] then - self._childs[#self._childs+1] = node - self._childs[node] = true - end - else - for _, obj in ipairs(node) do - if not self[obj] then - self[obj] = true - self[#self+1] = obj - end + for _, obj in ipairs(node) do + if not self[obj] then + self[obj] = true + self[#self+1] = obj end end else @@ -54,84 +42,25 @@ function mt:merge(node) end end -function mt:_each(mark, callback) - if mark[self] then - return - end - mark[self] = true - for i = 1, #self do - callback(self[i]) - end - local childs = self._childs - if not childs then - return - end - for i = 1, #childs do - local child = childs[i] - if not child:isLocked() then - child:_each(mark, callback) - end - end -end - -function mt:_expand() - local childs = self._childs - if not childs then - return - end - self._childs = nil - - local mark = {} - mark[self] = true - - local function insert(obj) - if not self[obj] then - self[obj] = true - self[#self+1] = obj - end - end - - for i = 1, #childs do - local child = childs[i] - if child:isLocked() then - if not self._childs then - self._childs = {} - end - if not self._childs[child] then - self._childs[#self._childs+1] = child - self._childs[child] = true - end - else - child:_each(mark, insert) - end - end -end - ---@return boolean function mt:isEmpty() - self:_expand() return #self == 0 end +function mt:clear() + self.optional = nil + for i, c in ipairs(self) do + self[i] = nil + self[c] = nil + end +end + ---@param n integer ---@return vm.object? function mt:get(n) - self:_expand() return self[n] end -function mt:lock() - self._locked = true -end - -function mt:unlock() - self._locked = false -end - -function mt:isLocked() - return self._locked == true -end - function mt:setData(k, v) if not self.data then self.data = {} @@ -147,49 +76,143 @@ function mt:getData(k) end function mt:addOptional() - if self:isOptional() then - return self - end self.optional = true end function mt:removeOptional() - if not self:isOptional() then - return self - end - self:_expand() - for i = #self, 1, -1 do - local n = self[i] - if n.type == 'nil' - or (n.type == 'boolean' and n[1] == false) - or (n.type == 'doc.type.boolean' and n[1] == false) then - self[i] = self[#self] - self[#self] = nil - end - end + self:remove 'nil' end ---@return boolean function mt:isOptional() - if self.optional ~= nil then - return self.optional + return self.optional == true +end + +---@return boolean +function mt:hasFalsy() + if self.optional then + return true end - self:_expand() for _, c in ipairs(self) do if c.type == 'nil' + or (c.type == 'global' and c.cate == 'type' and c.name == 'nil') + or (c.type == 'global' and c.cate == 'type' and c.name == 'false') or (c.type == 'boolean' and c[1] == false) or (c.type == 'doc.type.boolean' and c[1] == false) then - self.optional = true return true end end - self.optional = false return false end +---@return boolean +function mt:isNullable() + if self.optional then + return true + end + if #self == 0 then + return true + end + for _, c in ipairs(self) do + if c.type == 'nil' + or (c.type == 'global' and c.cate == 'type' and c.name == 'nil') + or (c.type == 'global' and c.cate == 'type' and c.name == 'any') then + return true + end + end + return false +end + +---@return vm.node +function mt:setTruthy() + if self.optional == true then + self.optional = nil + end + local hasBoolean + for index = #self, 1, -1 do + local c = self[index] + if c.type == 'nil' + or (c.type == 'global' and c.cate == 'type' and c.name == 'nil') + or (c.type == 'global' and c.cate == 'type' and c.name == 'false') + or (c.type == 'boolean' and c[1] == false) + or (c.type == 'doc.type.boolean' and c[1] == false) then + table.remove(self, index) + self[c] = nil + goto CONTINUE + end + if (c.type == 'global' and c.cate == 'type' and c.name == 'boolean') + or (c.type == 'boolean' or c.type == 'doc.type.boolean') then + hasBoolean = true + table.remove(self, index) + self[c] = nil + goto CONTINUE + end + ::CONTINUE:: + end + if hasBoolean then + self[#self+1] = vm.declareGlobal('type', 'true') + end +end + +---@return vm.node +function mt:setFalsy() + if self.optional == false then + self.optional = nil + end + local hasBoolean + for index = #self, 1, -1 do + local c = self[index] + if c.type == 'nil' + or (c.type == 'global' and c.cate == 'type' and c.name == 'nil') + or (c.type == 'global' and c.cate == 'type' and c.name == 'false') + or (c.type == 'boolean' and c[1] == true) + or (c.type == 'doc.type.boolean' and c[1] == true) then + goto CONTINUE + end + if (c.type == 'global' and c.cate == 'type' and c.name == 'boolean') + or (c.type == 'boolean' or c.type == 'doc.type.boolean') then + hasBoolean = true + table.remove(self, index) + self[c] = nil + end + ::CONTINUE:: + end + if hasBoolean then + self[#self+1] = vm.declareGlobal('type', 'false') + end +end + +---@param name string +function mt:remove(name) + if name == 'nil' and self.optional == true then + self.optional = nil + end + for index = #self, 1, -1 do + local c = self[index] + if (c.type == 'global' and c.cate == 'type' and c.name == name) + or (c.type == name) + or (c.type == 'doc.type.integer' and (name == 'number' or name == 'integer')) + or (c.type == 'doc.type.boolean' and name == 'boolean') + or (c.type == 'doc.type.table' and name == 'table') + or (c.type == 'doc.type.array' and name == 'table') + or (c.type == 'doc.type.function' and name == 'function') then + table.remove(self, index) + self[c] = nil + end + end +end + +---@param node vm.node +function mt:removeNode(node) + for _, c in ipairs(node) do + if c.type == 'global' and c.cate == 'type' then + self:remove(c.name) + end + end +end + ---@return fun():vm.object function mt:eachObject() - self:_expand() local i = 0 return function () i = i + 1 @@ -197,12 +220,21 @@ function mt:eachObject() end end ----@param source parser.object | vm.generic +---@return vm.node +function mt:copy() + return vm.createNode(self) +end + +---@param source vm.object ---@param node vm.node | vm.object ---@param cover? boolean function vm.setNode(source, node, cover) if not node then - error('Can not set nil node') + if TEST then + error('Can not set nil node') + else + log.error('Can not set nil node') + end end if source.type == 'global' then error('Can not set node to global') @@ -216,13 +248,14 @@ function vm.setNode(source, node, cover) me:merge(node) else if node.type == 'vm.node' then - vm.nodeCache[source] = node + vm.nodeCache[source] = node:copy() else vm.nodeCache[source] = vm.createNode(node) end end end +---@param source vm.object ---@return vm.node? function vm.getNode(source) return vm.nodeCache[source] @@ -256,11 +289,16 @@ function vm.clearNodeCache() vm.nodeCache = {} end +local ID = 0 + ---@param a? vm.node | vm.object ---@param b? vm.node | vm.object ---@return vm.node function vm.createNode(a, b) - local node = setmetatable({}, mt) + ID = ID + 1 + local node = setmetatable({ + id = ID, + }, mt) if a then node:merge(a) end diff --git a/script/vm/ref.lua b/script/vm/ref.lua index 65e8fdab..545c294a 100644 --- a/script/vm/ref.lua +++ b/script/vm/ref.lua @@ -2,8 +2,6 @@ local vm = require 'vm.vm' local util = require 'utility' local guide = require 'parser.guide' -local localID = require 'vm.local-id' -local globalMgr = require 'vm.global-manager' local files = require 'files' local await = require 'await' local progress = require 'progress' @@ -242,7 +240,7 @@ end ---@param source parser.object ---@param pushResult fun(src: parser.object) local function searchByLocalID(source, pushResult) - local idSources = localID.getSources(source) + local idSources = vm.getLocalSources(source) if not idSources then return end @@ -291,7 +289,7 @@ end ---@async ---@param source parser.object ----@param fileNotify fun(uri: uri): boolean +---@param fileNotify? fun(uri: uri): boolean function vm.getRefs(source, fileNotify) local results = {} local mark = {} diff --git a/script/vm/runner.lua b/script/vm/runner.lua new file mode 100644 index 00000000..9fe0f172 --- /dev/null +++ b/script/vm/runner.lua @@ -0,0 +1,444 @@ +---@class vm +local vm = require 'vm.vm' +local guide = require 'parser.guide' + +---@class vm.runner +---@field loc parser.object +---@field mainBlock parser.object +---@field blocks table<parser.object, true> +---@field steps vm.runner.step[] +local mt = {} +mt.__index = mt +mt.index = 1 + +---@class parser.object +---@field _casts parser.object[] + +---@class vm.runner.step +---@field type 'truthy' | 'falsy' | 'as' | 'add' | 'remove' | 'object' | 'save' | 'push' | 'merge' | 'cast' +---@field pos integer +---@field order? integer +---@field node? vm.node +---@field object? parser.object +---@field name? string +---@field cast? parser.object +---@field tag? string +---@field copy? boolean +---@field new? boolean +---@field ref1? vm.runner.step +---@field ref2? vm.runner.step + +---@param filter parser.object +---@param outStep vm.runner.step +---@param blockStep vm.runner.step +function mt:_compileNarrowByFilter(filter, outStep, blockStep) + if not filter then + return + end + if filter.type == 'paren' then + if filter.exp then + self:_compileNarrowByFilter(filter.exp, outStep, blockStep) + end + return + end + if filter.type == 'unary' then + if not filter.op + or not filter[1] then + return + end + if filter.op.type == 'not' then + local exp = filter[1] + if exp.type == 'getlocal' and exp.node == self.loc then + self.steps[#self.steps+1] = { + type = 'falsy', + pos = filter.finish, + new = true, + } + self.steps[#self.steps+1] = { + type = 'truthy', + pos = filter.finish, + ref1 = outStep, + } + end + end + elseif filter.type == 'binary' then + if not filter.op + or not filter[1] + or not filter[2] then + return + end + if filter.op.type == 'and' then + local dummyStep = { + type = 'save', + copy = true, + ref1 = outStep, + pos = filter.start - 1, + } + self.steps[#self.steps+1] = dummyStep + self:_compileNarrowByFilter(filter[1], dummyStep, blockStep) + self:_compileNarrowByFilter(filter[2], dummyStep, blockStep) + end + if filter.op.type == 'or' then + self:_compileNarrowByFilter(filter[1], outStep, blockStep) + local dummyStep = { + type = 'push', + copy = true, + ref1 = outStep, + pos = filter.op.finish, + } + self.steps[#self.steps+1] = dummyStep + self:_compileNarrowByFilter(filter[2], outStep, dummyStep) + self.steps[#self.steps+1] = { + type = 'push', + tag = 'or reset', + ref1 = blockStep, + pos = filter.finish, + } + end + if filter.op.type == '==' + or filter.op.type == '~=' then + local loc, exp + for i = 1, 2 do + loc = filter[i] + if loc.type == 'getlocal' and loc.node == self.loc then + exp = filter[i % 2 + 1] + break + end + end + if not loc or not exp then + return + end + if guide.isLiteral(exp) then + if filter.op.type == '==' then + self.steps[#self.steps+1] = { + type = 'remove', + name = exp.type, + pos = filter.finish, + ref1 = outStep, + } + self.steps[#self.steps+1] = { + type = 'as', + name = exp.type, + pos = filter.finish, + new = true, + } + end + if filter.op.type == '~=' then + self.steps[#self.steps+1] = { + type = 'as', + name = exp.type, + pos = filter.finish, + ref1 = outStep, + } + self.steps[#self.steps+1] = { + type = 'remove', + name = exp.type, + pos = filter.finish, + new = true, + } + end + end + end + else + if filter.type == 'getlocal' and filter.node == self.loc then + self.steps[#self.steps+1] = { + type = 'truthy', + pos = filter.finish, + new = true, + } + self.steps[#self.steps+1] = { + type = 'falsy', + pos = filter.finish, + ref1 = outStep, + } + end + end +end + +---@param block parser.object +function mt:_compileBlock(block) + if self.blocks[block] then + return + end + self.blocks[block] = true + if block == self.mainBlock then + return + end + + local parentBlock = guide.getParentBlock(block) + self:_compileBlock(parentBlock) + + if block.type == 'if' then + ---@type vm.runner.step[] + local finals = {} + for i, childBlock in ipairs(block) do + local blockStep = { + type = 'save', + tag = 'block', + copy = true, + pos = childBlock.start, + } + local outStep = { + type = 'save', + tag = 'out', + copy = true, + pos = childBlock.start, + } + self.steps[#self.steps+1] = blockStep + self.steps[#self.steps+1] = outStep + self.steps[#self.steps+1] = { + type = 'push', + ref1 = blockStep, + pos = childBlock.start, + } + self:_compileNarrowByFilter(childBlock.filter, outStep, blockStep) + if not childBlock.hasReturn + and not childBlock.hasGoTo + and not childBlock.hasBreak then + local finalStep = { + type = 'save', + pos = childBlock.finish, + tag = 'final #' .. i, + } + finals[#finals+1] = finalStep + self.steps[#self.steps+1] = finalStep + end + self.steps[#self.steps+1] = { + type = 'push', + tag = 'reset child', + ref1 = outStep, + pos = childBlock.finish, + } + end + self.steps[#self.steps+1] = { + type = 'push', + tag = 'reset if', + pos = block.finish, + copy = true, + } + for _, final in ipairs(finals) do + self.steps[#self.steps+1] = { + type = 'merge', + ref2 = final, + pos = block.finish, + } + end + end + + if block.type == 'function' + or block.type == 'while' + or block.type == 'loop' + or block.type == 'in' + or block.type == 'repeat' + or block.type == 'for' then + local savePoint = { + type = 'save', + copy = true, + pos = block.start, + } + self.steps[#self.steps+1] = { + type = 'push', + copy = true, + pos = block.start, + } + self.steps[#self.steps+1] = savePoint + self.steps[#self.steps+1] = { + type = 'push', + pos = block.finish, + ref1 = savePoint, + } + end +end + +---@return parser.object[] +function mt:_getCasts() + local root = guide.getRoot(self.loc) + if not root._casts then + root._casts = {} + local docs = root.docs + for _, doc in ipairs(docs) do + if doc.type == 'doc.cast' and doc.loc then + root._casts[#root._casts+1] = doc + end + end + end + return root._casts +end + +function mt:_preCompile() + local startPos = self.loc.start + local finishPos = 0 + + for _, ref in ipairs(self.loc.ref) do + self.steps[#self.steps+1] = { + type = 'object', + object = ref, + pos = ref.range or ref.start, + } + if ref.start > finishPos then + finishPos = ref.start + end + local block = guide.getParentBlock(ref) + self:_compileBlock(block) + end + + for i, step in ipairs(self.steps) do + if step.type ~= 'object' then + step.order = i + end + end + + local casts = self:_getCasts() + for _, cast in ipairs(casts) do + if cast.loc[1] == self.loc[1] + and cast.start > startPos + and cast.finish < finishPos + and guide.getLocal(self.loc, self.loc[1], cast.start) == self.loc then + self.steps[#self.steps+1] = { + type = 'cast', + cast = cast, + pos = cast.start, + } + end + end + + table.sort(self.steps, function (a, b) + if a.pos == b.pos then + return (a.order or 0) < (b.order or 0) + else + return a.pos < b.pos + end + end) +end + +---@param loc parser.object +---@param node vm.node +---@return vm.node +local function checkAssert(loc, node) + local parent = loc.parent + if parent.type == 'binary' then + if parent.op and (parent.op.type == '~=' or parent.op.type == '==') then + local exp + for i = 1, 2 do + if parent[i] == loc then + exp = parent[i % 2 + 1] + end + end + if exp and guide.isLiteral(exp) then + local callargs = parent.parent + if callargs.type == 'callargs' + and callargs.parent.node.special == 'assert' + and callargs[1] == parent then + if parent.op.type == '~=' then + node:remove(exp.type) + end + if parent.op.type == '==' then + node = vm.compileNode(exp) + end + end + end + end + end + if parent.type == 'callargs' + and parent.parent.node.special == 'assert' + and parent[1] == loc then + node:setTruthy() + end + return node +end + +---@param callback fun(src: parser.object, node: vm.node) +function mt:launch(callback) + local topNode = vm.getNode(self.loc):copy() + for _, step in ipairs(self.steps) do + local node = step.ref1 and step.ref1.node or topNode + if step.type == 'truthy' then + if step.new then + node = node:copy() + topNode = node + end + node:setTruthy() + elseif step.type == 'falsy' then + if step.new then + node = node:copy() + topNode = node + end + node:setFalsy() + elseif step.type == 'as' then + if step.new then + topNode = vm.createNode(vm.getGlobal('type', step.name)) + else + node:clear() + node:merge(vm.getGlobal('type', step.name)) + end + elseif step.type == 'add' then + if step.new then + node = node:copy() + topNode = node + end + node:merge(vm.getGlobal('type', step.name)) + elseif step.type == 'remove' then + if step.new then + node = node:copy() + topNode = node + end + node:remove(step.name) + elseif step.type == 'object' then + topNode = callback(step.object, node) or node + if step.object.type == 'getlocal' then + topNode = checkAssert(step.object, node) + end + elseif step.type == 'save' then + if step.copy then + node = node:copy() + end + step.node = node + elseif step.type == 'push' then + if step.copy then + node = node:copy() + end + topNode = node + elseif step.type == 'merge' then + node:merge(step.ref2.node) + elseif step.type == 'cast' then + topNode = node:copy() + for _, cast in ipairs(step.cast.casts) do + if cast.mode == '+' then + if cast.optional then + topNode:addOptional() + end + if cast.extends then + topNode:merge(vm.compileNode(cast.extends)) + end + elseif cast.mode == '-' then + if cast.optional then + topNode:removeOptional() + end + if cast.extends then + topNode:removeNode(vm.compileNode(cast.extends)) + end + else + if cast.extends then + topNode:clear() + topNode:merge(vm.compileNode(cast.extends)) + end + end + end + end + end +end + +---@param loc parser.object +---@return vm.runner +function vm.createRunner(loc) + local self = setmetatable({ + loc = loc, + mainBlock = guide.getParentBlock(loc), + blocks = {}, + steps = {}, + }, mt) + + self:_preCompile() + + return self +end diff --git a/script/vm/sign.lua b/script/vm/sign.lua index 2d45a5a7..fe112bc2 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -1,6 +1,6 @@ local guide = require 'parser.guide' +---@class vm local vm = require 'vm.vm' -local infer = require 'vm.infer' ---@class vm.sign ---@field parent parser.object @@ -16,12 +16,12 @@ end ---@param uri uri ---@param args parser.object +---@param removeGeneric true? ---@return table<string, vm.node> -function mt:resolve(uri, args) +function mt:resolve(uri, args, removeGeneric) if not args then return nil end - local globalMgr = require 'vm.global-manager' local resolved = {} ---@param object parser.object @@ -33,7 +33,7 @@ function mt:resolve(uri, args) -- 'number' -> `T` for n in node:eachObject() do if n.type == 'string' then - local type = globalMgr.declareGlobal('type', n[1], guide.getUri(n)) + local type = vm.declareGlobal('type', n[1], guide.getUri(n)) resolved[key] = vm.createNode(type, resolved[key]) end end @@ -48,6 +48,19 @@ function mt:resolve(uri, args) -- number[] -> T[] resolve(object.node, vm.compileNode(n.node)) end + if n.type == 'doc.type.table' then + -- { [integer]: number } -> T[] + local tvalueNode = vm.getTableValue(uri, node, 'integer') + if tvalueNode then + resolve(object.node, tvalueNode) + end + end + if n.type == 'global' and n.cate == 'type' then + -- ---@field [integer]: number -> T[] + vm.getClassFields(uri, n, vm.declareGlobal('type', 'integer'), false, function (field) + resolve(object.node, vm.compileNode(field.extends)) + end) + end end end if object.type == 'doc.type.table' then @@ -98,7 +111,7 @@ function mt:resolve(uri, args) goto CONTINUE end end - local view = infer.viewObject(obj) + local view = vm.viewObject(obj) if view then knownTypes[view] = true end @@ -114,10 +127,10 @@ function mt:resolve(uri, args) local function buildArgNode(argNode, knownTypes) local newArgNode = vm.createNode() for n in argNode:eachObject() do - if argNode:isOptional() and vm.isFalsy(n) then + if argNode:hasFalsy() then goto CONTINUE end - local view = infer.viewObject(n) + local view = vm.viewObject(n) if knownTypes[view] then goto CONTINUE end @@ -156,7 +169,7 @@ function mt:resolve(uri, args) end ---@return vm.sign -return function () +function vm.createSign() local genericMgr = setmetatable({ signList = {}, }, mt) diff --git a/script/vm/type.lua b/script/vm/type.lua index fa02d19e..c3264993 100644 --- a/script/vm/type.lua +++ b/script/vm/type.lua @@ -1,4 +1,3 @@ -local globalMgr = require 'vm.global-manager' ---@class vm local vm = require 'vm.vm' @@ -9,10 +8,10 @@ local vm = require 'vm.vm' ---@return boolean function vm.isSubType(uri, child, parent, mark) if type(parent) == 'string' then - parent = vm.createNode(globalMgr.getGlobal('type', parent)) + parent = vm.createNode(vm.getGlobal('type', parent)) end if type(child) == 'string' then - child = vm.createNode(globalMgr.getGlobal('type', child)) + child = vm.createNode(vm.getGlobal('type', child)) end if not child or not parent then @@ -134,7 +133,7 @@ function vm.getTableKey(uri, tnode, vnode) end end if tn.type == 'doc.type.array' then - result:merge(globalMgr.getGlobal('type', 'integer')) + result:merge(vm.declareGlobal('type', 'integer')) end if tn.type == 'table' then for _, field in ipairs(tn) do @@ -144,10 +143,10 @@ function vm.getTableKey(uri, tnode, vnode) end end if field.type == 'tablefield' then - result:merge(globalMgr.getGlobal('type', 'string')) + result:merge(vm.declareGlobal('type', 'string')) end if field.type == 'tableexp' then - result:merge(globalMgr.getGlobal('type', 'integer')) + result:merge(vm.declareGlobal('type', 'integer')) end end end diff --git a/script/vm/value.lua b/script/vm/value.lua index a784be2a..d29ca9d0 100644 --- a/script/vm/value.lua +++ b/script/vm/value.lua @@ -17,7 +17,16 @@ function vm.test(source) hasTrue = true end if n[1] == false then - hasTrue = false + hasFalse = true + end + end + if n.type == 'global' and n.cate == 'type' then + if n.name == 'true' then + hasTrue = true + end + if n.name == 'false' + or n.name == 'nil' then + hasFalse = true end end if n.type == 'nil' then @@ -41,28 +50,9 @@ function vm.test(source) end end ----@param source parser.object ----@return boolean -function vm.isFalsy(source) - if source.type == 'nil' then - return true - end - if source.type == 'boolean' - or source.type == 'doc.type.boolean' then - return source[1] == false - end - return false -end - ---@param v vm.object ---@return string? local function getUnique(v) - if v.type == 'local' then - return ('loc:%s@%d'):format(guide.getUri(v), v.start) - end - if v.type == 'global' then - return ('%s:%s'):format(v.cate, v.name) - end if v.type == 'boolean' then if v[1] == nil then return false diff --git a/script/vm/vm.lua b/script/vm/vm.lua index 3c1762bf..8117d311 100644 --- a/script/vm/vm.lua +++ b/script/vm/vm.lua @@ -23,6 +23,7 @@ function m.getSpecial(source) return source.special end +---@return string? function m.getKeyName(source) if not source then return nil |