summaryrefslogtreecommitdiff
path: root/server-beta/src/vm
diff options
context:
space:
mode:
Diffstat (limited to 'server-beta/src/vm')
-rw-r--r--server-beta/src/vm/eachDef.lua65
-rw-r--r--server-beta/src/vm/eachField.lua163
-rw-r--r--server-beta/src/vm/eachRef.lua504
-rw-r--r--server-beta/src/vm/getGlobal.lua6
-rw-r--r--server-beta/src/vm/getGlobals.lua45
-rw-r--r--server-beta/src/vm/getLibrary.lua60
-rw-r--r--server-beta/src/vm/getLinks.lua48
-rw-r--r--server-beta/src/vm/getValue.lua452
-rw-r--r--server-beta/src/vm/init.lua10
-rw-r--r--server-beta/src/vm/special.lua0
-rw-r--r--server-beta/src/vm/vm.lua81
11 files changed, 1434 insertions, 0 deletions
diff --git a/server-beta/src/vm/eachDef.lua b/server-beta/src/vm/eachDef.lua
new file mode 100644
index 00000000..0274cbee
--- /dev/null
+++ b/server-beta/src/vm/eachDef.lua
@@ -0,0 +1,65 @@
+local vm = require 'vm.vm'
+local guide = require 'parser.guide'
+local files = require 'files'
+
+local function checkPath(source, info)
+ if source.type == 'goto' then
+ return true
+ end
+ local src = info.source
+ local mode = guide.getPath(source, src)
+ if not mode then
+ return true
+ end
+ if mode == 'before' then
+ return false
+ end
+ return true
+end
+
+function vm.eachDef(source, callback)
+ local results = {}
+ local valueUris = {}
+ local sourceUri = guide.getRoot(source).uri
+ vm.eachRef(source, function (info)
+ if info.mode == 'declare'
+ or info.mode == 'set'
+ or info.mode == 'return'
+ or info.mode == 'value' then
+ results[#results+1] = info
+ local src = info.source
+ if info.mode == 'return' then
+ local uri = guide.getRoot(src).uri
+ valueUris[uri] = info.source
+ end
+ end
+ end)
+
+ for _, info in ipairs(results) do
+ local src = info.source
+ local destUri = guide.getRoot(src).uri
+ -- 如果是同一个文件,则检查位置关系后放行
+ if sourceUri == destUri then
+ if checkPath(source, info) then
+ callback(info)
+ end
+ goto CONTINUE
+ end
+ -- 如果是global或field,则直接放行(因为无法确定顺序)
+ if src.type == 'setindex'
+ or src.type == 'setfield'
+ or src.type == 'setmethod'
+ or src.type == 'tablefield'
+ or src.type == 'tableindex'
+ or src.type == 'setglobal' then
+ callback(info)
+ goto CONTINUE
+ end
+ -- 如果不是同一个文件,则必须在该文件 return 后才放行
+ if valueUris[destUri] then
+ callback(info)
+ goto CONTINUE
+ end
+ ::CONTINUE::
+ end
+end
diff --git a/server-beta/src/vm/eachField.lua b/server-beta/src/vm/eachField.lua
new file mode 100644
index 00000000..549a7dec
--- /dev/null
+++ b/server-beta/src/vm/eachField.lua
@@ -0,0 +1,163 @@
+local guide = require 'parser.guide'
+local vm = require 'vm.vm'
+
+local function ofTabel(value, callback)
+ for _, field in ipairs(value) do
+ if field.type == 'tablefield'
+ or field.type == 'tableindex' then
+ callback {
+ source = field,
+ key = guide.getKeyName(field),
+ value = field.value,
+ mode = 'set',
+ }
+ end
+ end
+end
+
+local function ofENV(source, callback)
+ if source.type == 'getlocal' then
+ local parent = source.parent
+ if parent.type == 'getfield'
+ or parent.type == 'getmethod'
+ or parent.type == 'getindex' then
+ callback {
+ source = parent,
+ key = guide.getKeyName(parent),
+ mode = 'get',
+ }
+ end
+ elseif source.type == 'getglobal' then
+ callback {
+ source = source,
+ key = guide.getKeyName(source),
+ mode = 'get',
+ }
+ elseif source.type == 'setglobal' then
+ callback {
+ source = source,
+ key = guide.getKeyName(source),
+ mode = 'set',
+ value = source.value,
+ }
+ end
+end
+
+local function ofSpecialArg(source, callback)
+ local args = source.parent
+ local call = args.parent
+ local func = call.node
+ local name = func.special
+ if name == 'rawset' then
+ if args[1] == source and args[2] then
+ callback {
+ source = call,
+ key = guide.getKeyName(args[2]),
+ value = args[3],
+ mode = 'set',
+ }
+ end
+ elseif name == 'rawget' then
+ if args[1] == source and args[2] then
+ callback {
+ source = call,
+ key = guide.getKeyName(args[2]),
+ mode = 'get',
+ }
+ end
+ elseif name == 'setmetatable' then
+ if args[1] == source and args[2] then
+ vm.eachField(args[2], function (info)
+ if info.key == 's|__index' and info.value then
+ vm.eachField(info.value, callback)
+ end
+ end)
+ end
+ end
+end
+
+local function ofVar(source, callback)
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ if parent.type == 'getfield'
+ or parent.type == 'getmethod'
+ or parent.type == 'getindex' then
+ callback {
+ source = parent,
+ key = guide.getKeyName(parent),
+ mode = 'get',
+ }
+ return
+ end
+ if parent.type == 'setfield'
+ or parent.type == 'setmethod'
+ or parent.type == 'setindex' then
+ callback {
+ source = parent,
+ key = guide.getKeyName(parent),
+ value = parent.value,
+ mode = 'set',
+ }
+ return
+ end
+ if parent.type == 'callargs' then
+ ofSpecialArg(source, callback)
+ end
+end
+
+local function eachField(source, callback)
+ vm.eachRef(source, function (info)
+ local src = info.source
+ if src.tag == '_ENV' then
+ if src.ref then
+ for _, ref in ipairs(src.ref) do
+ ofENV(ref, callback)
+ end
+ end
+ elseif src.type == 'getlocal'
+ or src.type == 'getglobal'
+ or src.type == 'getfield'
+ or src.type == 'getmethod'
+ or src.type == 'getindex' then
+ ofVar(src, callback)
+ elseif src.type == 'table' then
+ ofTabel(src, callback)
+ end
+ end)
+end
+
+--- 获取所有的field
+function vm.eachField(source, callback)
+ local cache = vm.cache.eachField[source]
+ if cache then
+ for i = 1, #cache do
+ callback(cache[i])
+ end
+ return
+ end
+ local unlock = vm.lock('eachField', source)
+ if not unlock then
+ return
+ end
+ cache = {}
+ vm.cache.eachField[source] = cache
+ local mark = {}
+ eachField(source, function (info)
+ local src = info.source
+ if mark[src] then
+ return
+ end
+ mark[src] = true
+ cache[#cache+1] = info
+ end)
+ unlock()
+ vm.eachRef(source, function (info)
+ local src = info.source
+ vm.cache.eachField[src] = cache
+ end)
+ for i = 1, #cache do
+ callback(cache[i])
+ end
+end
diff --git a/server-beta/src/vm/eachRef.lua b/server-beta/src/vm/eachRef.lua
new file mode 100644
index 00000000..543a0c09
--- /dev/null
+++ b/server-beta/src/vm/eachRef.lua
@@ -0,0 +1,504 @@
+local guide = require 'parser.guide'
+local files = require 'files'
+local vm = require 'vm.vm'
+
+local function ofCall(func, index, callback)
+ vm.eachRef(func, function (info)
+ local src = info.source
+ local returns
+ if info.mode == 'main' then
+ returns = src.returns
+ else
+ local funcDef = src.value
+ returns = funcDef and funcDef.returns
+ end
+ if returns then
+ -- 搜索函数第 index 个返回值
+ for _, rtn in ipairs(returns) do
+ local val = rtn[index]
+ if val then
+ callback {
+ source = val,
+ mode = 'return',
+ }
+ vm.eachRef(val, callback)
+ end
+ end
+ end
+ end)
+end
+
+local function ofCallSelect(call, index, callback)
+ local slc = call.parent
+ if slc.index == index then
+ vm.eachRef(slc.parent, callback)
+ return
+ end
+ if call.extParent then
+ for i = 1, #call.extParent do
+ slc = call.extParent[i]
+ if slc.index == index then
+ vm.eachRef(slc.parent, callback)
+ return
+ end
+ end
+ end
+end
+
+local function ofReturn(rtn, index, callback)
+ local func = guide.getParentFunction(rtn)
+ if not func then
+ return
+ end
+ -- 搜索函数调用的第 index 个接收值
+ if func.type == 'main' then
+ vm.eachRef(func, callback)
+ else
+ vm.eachRef(func, function (info)
+ local source = info.source
+ local call = source.parent
+ if not call or call.type ~= 'call' then
+ return
+ end
+ ofCallSelect(call, index, callback)
+ end)
+ end
+end
+
+local function ofSpecialCall(call, func, index, callback)
+ local name = func.special
+ if name == 'setmetatable' then
+ if index == 1 then
+ local args = call.args
+ if args[1] then
+ vm.eachRef(args[1], callback)
+ end
+ if args[2] then
+ vm.eachField(args[2], function (info)
+ if info.key == 's|__index' then
+ vm.eachRef(info.source, callback)
+ if info.value then
+ vm.eachRef(info.value, callback)
+ end
+ end
+ end)
+ end
+ end
+ elseif name == 'require' then
+ if index == 1 then
+ local result = vm.getLinkUris(call)
+ if result then
+ local myUri = guide.getRoot(call).uri
+ for _, uri in ipairs(result) do
+ if not files.eq(uri, myUri) then
+ local ast = files.getAst(uri)
+ if ast then
+ ofCall(ast.ast, 1, callback)
+ end
+ end
+ end
+ end
+ end
+ end
+end
+
+local function ofValue(value, callback)
+ if value.type == 'select' then
+ -- 检查函数返回值
+ local call = value.vararg
+ if call.type == 'call' then
+ ofCall(call.node, value.index, callback)
+ ofSpecialCall(call, call.node, value.index, callback)
+ end
+ return
+ end
+
+ if value.type == 'table'
+ or value.type == 'string'
+ or value.type == 'number'
+ or value.type == 'boolean'
+ or value.type == 'nil'
+ or value.type == 'function' then
+ callback {
+ source = value,
+ mode = 'value',
+ }
+ end
+
+ vm.eachRef(value, callback)
+
+ local parent = value.parent
+ if parent.type == 'local'
+ or parent.type == 'setglobal'
+ or parent.type == 'setlocal'
+ or parent.type == 'setfield'
+ or parent.type == 'setmethod'
+ or parent.type == 'setindex'
+ or parent.type == 'tablefield'
+ or parent.type == 'tableindex' then
+ if parent.value == value then
+ vm.eachRef(parent, callback)
+ end
+ end
+ if parent.type == 'return' then
+ for i = 1, #parent do
+ if parent[i] == value then
+ ofReturn(parent, i, callback)
+ break
+ end
+ end
+ end
+end
+
+local function ofSelf(loc, callback)
+ -- self 的2个特殊引用位置:
+ -- 1. 当前方法定义时的对象(mt)
+ local method = loc.method
+ local node = method.node
+ vm.eachRef(node, callback)
+ -- 2. 调用该方法时传入的对象
+end
+
+--- 自己作为赋值的值
+local function asValue(source, callback)
+ local parent = source.parent
+ if parent and parent.value == source then
+ if guide.getKeyString(parent) == '__index' then
+ if parent.type == 'tablefield'
+ or parent.type == 'tableindex' then
+ local t = parent.parent
+ local args = t.parent
+ if args[2] == t then
+ local call = args.parent
+ local func = call.node
+ if func.special == 'setmetatable' then
+ vm.eachRef(args[1], callback)
+ end
+ end
+ end
+ end
+ end
+end
+
+local function getCallRecvs(call)
+ local parent = call.parent
+ if parent.type ~= 'select' then
+ return nil
+ end
+ local exParent = call.exParent
+ local recvs = {}
+ recvs[1] = parent.parent
+ if exParent then
+ for _, p in ipairs(exParent) do
+ recvs[#recvs+1] = p.parent
+ end
+ end
+ return recvs
+end
+
+--- 自己作为函数的参数
+local function asArg(source, callback)
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ if parent.type == 'callargs' then
+ local call = parent.parent
+ local func = call.node
+ local name = func.special
+ if name == 'setmetatable' then
+ if parent[1] == source then
+ if parent[2] then
+ vm.eachField(parent[2], function (info)
+ if info.key == 's|__index' then
+ vm.eachRef(info.source, callback)
+ if info.value then
+ vm.eachRef(info.value, callback)
+ end
+ end
+ end)
+ end
+ end
+ local recvs = getCallRecvs(call)
+ if recvs and recvs[1] then
+ vm.eachRef(recvs[1], callback)
+ end
+ end
+ end
+end
+
+local function ofLocal(loc, callback)
+ -- 方法中的 self 使用了一个虚拟的定义位置
+ if loc.tag ~= 'self' then
+ callback {
+ source = loc,
+ mode = 'declare',
+ }
+ end
+ if loc.ref then
+ for _, ref in ipairs(loc.ref) do
+ if ref.type == 'getlocal' then
+ callback {
+ source = ref,
+ mode = 'get',
+ }
+ asValue(ref, callback)
+ elseif ref.type == 'setlocal' then
+ callback {
+ source = ref,
+ mode = 'set',
+ }
+ if ref.value then
+ ofValue(ref.value, callback)
+ end
+ end
+ end
+ end
+ if loc.tag == 'self' then
+ ofSelf(loc, callback)
+ end
+ if loc.value then
+ ofValue(loc.value, callback)
+ end
+ if loc.tag == '_ENV' and loc.ref then
+ for _, ref in ipairs(loc.ref) do
+ if ref.type == 'getlocal' then
+ local parent = ref.parent
+ if parent.type == 'getfield'
+ or parent.type == 'getindex' then
+ if guide.getKeyName(parent) == '_G' then
+ callback {
+ source = parent,
+ mode = 'get',
+ }
+ end
+ end
+ elseif ref.type == 'getglobal' then
+ if guide.getKeyString(ref) == '_G' then
+ callback {
+ source = ref,
+ mode = 'get',
+ }
+ end
+ end
+ end
+ end
+end
+
+local function ofGlobal(source, callback)
+ local key = guide.getKeyName(source)
+ local node = source.node
+ if node.tag == '_ENV' then
+ local uris = files.findGlobals(key)
+ for _, uri in ipairs(uris) do
+ local ast = files.getAst(uri)
+ local globals = vm.getGlobals(ast.ast)
+ if globals[key] then
+ for _, info in ipairs(globals[key]) do
+ callback(info)
+ if info.value then
+ ofValue(info.value, callback)
+ end
+ end
+ end
+ end
+ else
+ vm.eachField(node, function (info)
+ if key == info.key then
+ callback {
+ source = info.source,
+ mode = info.mode,
+ }
+ if info.value then
+ ofValue(info.value, callback)
+ end
+ end
+ end)
+ end
+end
+
+local function ofField(source, callback)
+ local parent = source.parent
+ local key = guide.getKeyName(source)
+ if parent.type == 'tablefield'
+ or parent.type == 'tableindex' then
+ local tbl = parent.parent
+ vm.eachField(tbl, function (info)
+ if key == info.key then
+ callback {
+ source = info.source,
+ mode = info.mode,
+ }
+ if info.value then
+ ofValue(info.value, callback)
+ end
+ end
+ end)
+ else
+ local node = parent.node
+ vm.eachField(node, function (info)
+ if key == info.key then
+ callback {
+ source = info.source,
+ mode = info.mode,
+ }
+ if info.value then
+ ofValue(info.value, callback)
+ end
+ end
+ end)
+ end
+end
+
+local function ofLiteral(source, callback)
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ if parent.type == 'setindex'
+ or parent.type == 'getindex'
+ or parent.type == 'tableindex' then
+ ofField(source, callback)
+ end
+end
+
+local function ofLabel(source, callback)
+ callback {
+ source = source,
+ mode = 'set',
+ }
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ callback {
+ source = ref,
+ mode = 'get',
+ }
+ end
+ end
+end
+
+local function ofGoTo(source, callback)
+ local name = source[1]
+ local label = guide.getLabel(source, name)
+ if label then
+ ofLabel(label, callback)
+ end
+end
+
+local function ofMain(source, callback)
+ callback {
+ source = source,
+ mode = 'main',
+ }
+ local myUri = source.uri
+ local uris = files.findLinkTo(myUri)
+ if not uris then
+ return
+ end
+ for _, uri in ipairs(uris) do
+ local ast = files.getAst(uri)
+ if ast then
+ local links = vm.getLinks(ast.ast)
+ if links then
+ for linkUri, calls in pairs(links) do
+ if files.eq(linkUri, myUri) then
+ for i = 1, #calls do
+ ofCallSelect(calls[i], 1, callback)
+ end
+ end
+ end
+ end
+ end
+ end
+end
+
+local function eachRef(source, callback)
+ local stype = source.type
+ if stype == 'local' then
+ ofLocal(source, callback)
+ elseif stype == 'getlocal'
+ or stype == 'setlocal' then
+ ofLocal(source.node, callback)
+ elseif stype == 'setglobal'
+ or stype == 'getglobal' then
+ ofGlobal(source, callback)
+ elseif stype == 'field'
+ or stype == 'method' then
+ ofField(source, callback)
+ elseif stype == 'setfield'
+ or stype == 'getfield' then
+ ofField(source.field, callback)
+ elseif stype == 'setmethod'
+ or stype == 'getmethod' then
+ ofField(source.method, callback)
+ elseif stype == 'number'
+ or stype == 'boolean'
+ or stype == 'string' then
+ ofLiteral(source, callback)
+ elseif stype == 'goto' then
+ ofGoTo(source, callback)
+ elseif stype == 'label' then
+ ofLabel(source, callback)
+ elseif stype == 'table'
+ or stype == 'function' then
+ ofValue(source, callback)
+ elseif stype == 'main' then
+ ofMain(source, callback)
+ end
+ asArg(source, callback)
+end
+
+--- 判断2个对象是否拥有相同的引用
+function vm.isSameRef(a, b)
+ local cache = vm.cache.eachRef[a]
+ if cache then
+ -- 相同引用的source共享同一份cache
+ return cache == vm.cache.eachRef[b]
+ else
+ return vm.eachRef(a, function (info)
+ if info.source == b then
+ return true
+ end
+ end) or false
+ end
+end
+
+--- 获取所有的引用
+function vm.eachRef(source, callback)
+ local cache = vm.cache.eachRef[source]
+ if cache then
+ for i = 1, #cache do
+ local res = callback(cache[i])
+ if res ~= nil then
+ return res
+ end
+ end
+ return
+ end
+ local unlock = vm.lock('eachRef', source)
+ if not unlock then
+ return
+ end
+ cache = {}
+ vm.cache.eachRef[source] = cache
+ local mark = {}
+ eachRef(source, function (info)
+ local src = info.source
+ if mark[src] then
+ return
+ end
+ mark[src] = true
+ cache[#cache+1] = info
+ end)
+ unlock()
+ for i = 1, #cache do
+ local src = cache[i].source
+ vm.cache.eachRef[src] = cache
+ end
+ for i = 1, #cache do
+ local res = callback(cache[i])
+ if res ~= nil then
+ return res
+ end
+ end
+end
diff --git a/server-beta/src/vm/getGlobal.lua b/server-beta/src/vm/getGlobal.lua
new file mode 100644
index 00000000..373c907e
--- /dev/null
+++ b/server-beta/src/vm/getGlobal.lua
@@ -0,0 +1,6 @@
+local vm = require 'vm.vm'
+
+function vm.getGlobal(source)
+ vm.getGlobals(source)
+ return vm.cache.getGlobal[source]
+end
diff --git a/server-beta/src/vm/getGlobals.lua b/server-beta/src/vm/getGlobals.lua
new file mode 100644
index 00000000..699dd270
--- /dev/null
+++ b/server-beta/src/vm/getGlobals.lua
@@ -0,0 +1,45 @@
+local guide = require 'parser.guide'
+local vm = require 'vm.vm'
+
+local function getGlobals(root)
+ local env = guide.getENV(root)
+ local cache = {}
+ local mark = {}
+ vm.eachField(env, function (info)
+ local src = info.source
+ if mark[src] then
+ return
+ end
+ mark[src] = true
+ local name = info.key
+ if not name then
+ return
+ end
+ if not cache[name] then
+ cache[name] = {
+ key = name,
+ mode = {},
+ }
+ end
+ cache[name][#cache[name]+1] = info
+ cache[name].mode[info.mode] = true
+ vm.cache.getGlobal[src] = name
+ end)
+ return cache
+end
+
+function vm.getGlobals(source)
+ source = guide.getRoot(source)
+ local cache = vm.cache.getGlobals[source]
+ if cache ~= nil then
+ return cache
+ end
+ local unlock = vm.lock('getGlobals', source)
+ if not unlock then
+ return nil
+ end
+ cache = getGlobals(source) or false
+ vm.cache.getGlobals[source] = cache
+ unlock()
+ return cache
+end
diff --git a/server-beta/src/vm/getLibrary.lua b/server-beta/src/vm/getLibrary.lua
new file mode 100644
index 00000000..08f015a6
--- /dev/null
+++ b/server-beta/src/vm/getLibrary.lua
@@ -0,0 +1,60 @@
+local vm = require 'vm.vm'
+local library = require 'library'
+local guide = require 'parser.guide'
+
+local function checkStdLibrary(source)
+ local globalName = vm.getGlobal(source)
+ if not globalName then
+ return nil
+ end
+ local name = globalName:match '^s|(.+)$'
+ if library.global[name] then
+ return library.global[name]
+ end
+end
+
+local function getLibrary(source)
+ local lib = checkStdLibrary(source)
+ if lib then
+ return lib
+ end
+ return vm.eachRef(source, function (info)
+ local src = info.source
+ if src.type ~= 'getfield'
+ and src.type ~= 'getmethod'
+ and src.type ~= 'getindex' then
+ return
+ end
+ local node = src.node
+ local nodeGlobalName = vm.getGlobal(node)
+ if not nodeGlobalName then
+ return
+ end
+ local nodeName = nodeGlobalName:match '^s|(.+)$'
+ local nodeLib = library.global[nodeName]
+ if not nodeLib then
+ return
+ end
+ if not nodeLib.child then
+ return
+ end
+ local key = guide.getKeyString(src)
+ local defLib = nodeLib.child[key]
+ return defLib
+ end)
+end
+
+function vm.getLibrary(source)
+ local cache = vm.cache.getLibrary[source]
+ if cache ~= nil then
+ return cache
+ end
+ local unlock = vm.lock('getLibrary', source)
+ if not unlock then
+ return
+ end
+ cache = getLibrary(source) or false
+ vm.cache.getLibrary[source] = cache
+ unlock()
+ return cache
+end
diff --git a/server-beta/src/vm/getLinks.lua b/server-beta/src/vm/getLinks.lua
new file mode 100644
index 00000000..6875771f
--- /dev/null
+++ b/server-beta/src/vm/getLinks.lua
@@ -0,0 +1,48 @@
+local guide = require 'parser.guide'
+local vm = require 'vm.vm'
+
+local function getLinks(root)
+ local cache = {}
+ local ok
+ guide.eachSpecialOf(root, 'require', function (source)
+ local call = source.parent
+ if call.type == 'call' then
+ local uris = vm.getLinkUris(call)
+ if uris then
+ ok = true
+ for i = 1, #uris do
+ local uri = uris[i]
+ if not cache[uri] then
+ cache[uri] = {}
+ end
+ cache[uri][#cache[uri]+1] = call
+ end
+ end
+ end
+ end)
+ if not ok then
+ return nil
+ end
+ return cache
+end
+
+function vm.getLinks(source)
+ source = guide.getRoot(source)
+ local cache = vm.cache.getLinks[source]
+ if cache ~= nil then
+ return cache
+ end
+ local unlock = vm.lock('getLinks', source)
+ if not unlock then
+ return nil
+ end
+ local clock = os.clock()
+ cache = getLinks(source) or false
+ local passed = os.clock() - clock
+ if passed > 0.1 then
+ log.warn(('getLinks takes [%.3f] sec!'):format(passed))
+ end
+ vm.cache.getLinks[source] = cache
+ unlock()
+ return cache
+end
diff --git a/server-beta/src/vm/getValue.lua b/server-beta/src/vm/getValue.lua
new file mode 100644
index 00000000..b13d822d
--- /dev/null
+++ b/server-beta/src/vm/getValue.lua
@@ -0,0 +1,452 @@
+local vm = require 'vm.vm'
+
+local typeSort = {
+ ['boolean'] = 1,
+ ['string'] = 2,
+ ['integer'] = 3,
+ ['number'] = 4,
+ ['table'] = 5,
+ ['function'] = 6,
+ ['nil'] = math.maxinteger,
+}
+
+NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end })
+
+local function merge(a, b)
+ local t = {}
+ for i = 1, #a do
+ t[#t+1] = a[i]
+ end
+ for i = 1, #b do
+ t[#t+1] = b[i]
+ end
+ return t
+end
+
+local function checkLiteral(source)
+ if source.type == 'string' then
+ return {
+ type = 'string',
+ value = source[1],
+ source = source,
+ }
+ elseif source.type == 'nil' then
+ return {
+ type = 'nil',
+ value = NIL,
+ source = source,
+ }
+ elseif source.type == 'boolean' then
+ return {
+ type = 'boolean',
+ value = source[1],
+ source = source,
+ }
+ elseif source.type == 'number' then
+ if math.type(source[1]) == 'integer' then
+ return {
+ type = 'integer',
+ value = source[1],
+ source = source,
+ }
+ else
+ return {
+ type = 'number',
+ value = source[1],
+ source = source,
+ }
+ end
+ elseif source.type == 'table' then
+ return {
+ type = 'table',
+ source = source,
+ }
+ elseif source.type == 'function' then
+ return {
+ type = 'function',
+ source = source,
+ }
+ end
+end
+
+local function checkUnary(source)
+ if source.type ~= 'unary' then
+ return
+ end
+ local op = source.op
+ if op.type == 'not' then
+ local isTrue = vm.isTrue(source[1])
+ local value = nil
+ if isTrue == true then
+ value = false
+ elseif isTrue == false then
+ value = true
+ end
+ return {
+ type = 'boolean',
+ value = value,
+ source = source,
+ }
+ elseif op.type == '#' then
+ return {
+ type = 'integer',
+ source = source,
+ }
+ elseif op.type == '~' then
+ local l = vm.getLiteral(source[1], 'integer')
+ return {
+ type = 'integer',
+ value = l and ~l or nil,
+ source = source,
+ }
+ elseif op.type == '-' then
+ local v = vm.getLiteral(source[1], 'integer')
+ if v then
+ return {
+ type = 'integer',
+ value = - v,
+ source = source,
+ }
+ end
+ v = vm.getLiteral(source[1], 'number')
+ return {
+ type = 'number',
+ value = v and -v or nil,
+ source = source,
+ }
+ end
+end
+
+local function checkBinary(source)
+ if source.type ~= 'binary' then
+ return
+ end
+ local op = source.op
+ if op.type == 'and' then
+ local isTrue = vm.checkTrue(source[1])
+ if isTrue == true then
+ return vm.getValue(source[2])
+ elseif isTrue == false then
+ return vm.getValue(source[1])
+ else
+ return merge(
+ vm.getValue(source[1]),
+ vm.getValue(source[2])
+ )
+ end
+ elseif op.type == 'or' then
+ local isTrue = vm.checkTrue(source[1])
+ if isTrue == true then
+ return vm.getValue(source[1])
+ elseif isTrue == false then
+ return vm.getValue(source[2])
+ else
+ return merge(
+ vm.getValue(source[1]),
+ vm.getValue(source[2])
+ )
+ end
+ elseif op.type == '==' then
+ local value = vm.isSameValue(source[1], source[2])
+ if value ~= nil then
+ return {
+ type = 'boolean',
+ value = value,
+ source = source,
+ }
+ end
+ local isSame = vm.isSameRef(source[1], source[2])
+ if isSame == true then
+ value = true
+ else
+ value = nil
+ end
+ return {
+ type = 'boolean',
+ value = value,
+ source = source,
+ }
+ elseif op.type == '~=' then
+ local value = vm.isSameValue(source[1], source[2])
+ if value ~= nil then
+ return {
+ type = 'boolean',
+ value = not value,
+ source = source,
+ }
+ end
+ local isSame = vm.isSameRef(source[1], source[2])
+ if isSame == true then
+ value = false
+ else
+ value = nil
+ end
+ return {
+ type = 'boolean',
+ value = value,
+ source = source,
+ }
+ elseif op.type == '<=' then
+ elseif op.type == '>='
+ or op.type == '<'
+ or op.type == '>' then
+ return 'boolean'
+ end
+ if op.type == '|'
+ or op.type == '~'
+ or op.type == '&'
+ or op.type == '<<'
+ or op.type == '>>' then
+ return 'integer'
+ end
+ if op.type == '..' then
+ return 'string'
+ end
+ if op.type == '^'
+ or op.type == '/' then
+ return 'number'
+ end
+ -- 其他数学运算根据2侧的值决定,当2侧的值均为整数时返回整数
+ if op.type == '+'
+ or op.type == '-'
+ or op.type == '*'
+ or op.type == '%'
+ or op.type == '//' then
+ if hasType('integer', vm.getValue(source[1]))
+ and hasType('integer', vm.getValue(source[2])) then
+ return 'integer'
+ else
+ return 'number'
+ end
+ end
+end
+
+local function checkValue(source)
+ if source.value then
+ return vm.getValue(source.value)
+ end
+end
+
+local function checkCall(result, source)
+ if not source.parent then
+ return
+ end
+ if source.parent.type ~= 'call' then
+ return
+ end
+ if source.parent.node == source then
+ merge(result, 'function')
+ return
+ end
+end
+
+local function checkNext(result, source)
+ local next = source.next
+ if not next then
+ return
+ end
+ if next.type == 'getfield'
+ or next.type == 'getindex'
+ or next.type == 'getmethod'
+ or next.type == 'setfield'
+ or next.type == 'setindex'
+ or next.type == 'setmethod' then
+ merge(result, 'table')
+ end
+end
+
+local function checkDef(result, source)
+ vm.eachDef(source, function (info)
+ local src = info.source
+ local tp = vm.getValue(src)
+ if tp then
+ merge(result, tp)
+ end
+ end)
+end
+
+local function typeInference(source)
+ local tp = checkLiteral(source)
+ or checkValue(source)
+ or checkUnary(source)
+ or checkBinary(source)
+ if tp then
+ return tp
+ end
+
+ local result = {}
+
+ checkCall(result, source)
+ checkNext(result, source)
+ checkDef(result, source)
+
+ return dump(result)
+end
+
+local function getValue(source)
+ local result = checkLiteral(source)
+ if result then
+ return { result }
+ end
+ local results = checkValue(source)
+ or checkUnary(source)
+ or checkBinary(source)
+ if results then
+ return results
+ end
+end
+
+function vm.checkTrue(source)
+ local values = vm.getValue(source)
+ if not values then
+ return
+ end
+ -- 当前认为的结果
+ local current
+ for i = 1, #values do
+ -- 新的结果
+ local new
+ local v = values[i]
+ if v.type == 'nil' then
+ new = false
+ elseif v.type == 'boolean' then
+ if v.value == true then
+ new = true
+ elseif v.value == false then
+ new = false
+ end
+ end
+ if new ~= nil then
+ if current == nil then
+ current = new
+ else
+ -- 如果2个结果完全相反,则返回 nil 表示不确定
+ if new ~= current then
+ return nil
+ end
+ end
+ end
+ end
+ return current
+end
+
+--- 拥有某个类型的值
+function vm.eachValueType(source, type, callback)
+ local values = vm.getValue(source)
+ if not values then
+ return
+ end
+ for i = 1, #values do
+ local v = values[i]
+ if v.type == type then
+ local res = callback(v)
+ if res ~= nil then
+ return res
+ end
+ end
+ end
+end
+
+--- 获取特定类型的字面量值
+function vm.getLiteral(source, type)
+ local values = vm.getValue(source)
+ if not values then
+ return nil
+ end
+ for i = 1, #values do
+ local v = values[i]
+ if v.type == type and v.value ~= nil then
+ return v.value
+ end
+ end
+ return nil
+end
+
+function vm.isSameValue(a, b)
+ local valuesA = vm.getValue(a)
+ local valuesB = vm.getValue(b)
+ if valuesA == valuesB and valuesA ~= nil then
+ return true
+ end
+ local values = {}
+ for i = 1, #valuesA do
+ local value = valuesA[i]
+ local literal = value.value
+ if literal then
+ values[literal] = false
+ end
+ end
+ for i = 1, #valuesB do
+ local value = valuesA[i]
+ local literal = value.value
+ if literal then
+ if values[literal] == nil then
+ return false
+ end
+ values[literal] = true
+ end
+ end
+ for k, v in pairs(values) do
+ if v == false then
+ return false
+ end
+ end
+ return true
+end
+
+function vm.typeInference(source)
+ local values = vm.getValue(source)
+ if not values then
+ return 'any'
+ end
+ local types = {}
+ for _ = 1, #values do
+ local tp = values.type
+ if not types[tp] then
+ types[tp] = true
+ types[#types+1] = tp
+ end
+ end
+ if #types == 0 then
+ return 'any'
+ end
+ if #types == 1 then
+ return types[1]
+ end
+ table.sort(types, function (a, b)
+ local sa = typeSort[a]
+ local sb = typeSort[b]
+ if sa and sb then
+ return sa < sb
+ end
+ if not sa and not sb then
+ return a < b
+ end
+ if sa and not sb then
+ return true
+ end
+ if not sa and sb then
+ return false
+ end
+ return false
+ end)
+ return table.concat(types, '|')
+end
+
+function vm.getValue(source)
+ if not source then
+ return
+ end
+ local cache = vm.cache.getValue[source]
+ if cache ~= nil then
+ return cache
+ end
+ local unlock = vm.lock('getValue', source)
+ if not unlock then
+ return
+ end
+ cache = getValue(source) or false
+ vm.cache.getValue[source] = cache
+ unlock()
+ return cache
+end
diff --git a/server-beta/src/vm/init.lua b/server-beta/src/vm/init.lua
new file mode 100644
index 00000000..bf63db1d
--- /dev/null
+++ b/server-beta/src/vm/init.lua
@@ -0,0 +1,10 @@
+local vm = require 'vm.vm'
+require 'vm.eachField'
+require 'vm.eachRef'
+require 'vm.eachDef'
+require 'vm.getGlobals'
+require 'vm.getLinks'
+require 'vm.getGlobal'
+require 'vm.getLibrary'
+require 'vm.getValue'
+return vm
diff --git a/server-beta/src/vm/special.lua b/server-beta/src/vm/special.lua
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/server-beta/src/vm/special.lua
diff --git a/server-beta/src/vm/vm.lua b/server-beta/src/vm/vm.lua
new file mode 100644
index 00000000..23a691df
--- /dev/null
+++ b/server-beta/src/vm/vm.lua
@@ -0,0 +1,81 @@
+local guide = require 'parser.guide'
+local util = require 'utility'
+
+local setmetatable = setmetatable
+local assert = assert
+local require = require
+local type = type
+
+_ENV = nil
+
+local specials = {
+ ['_G'] = true,
+ ['rawset'] = true,
+ ['rawget'] = true,
+ ['setmetatable'] = true,
+ ['require'] = true,
+ ['dofile'] = true,
+ ['loadfile'] = true,
+}
+
+---@class vm
+local m = {}
+
+function m.lock(tp, source)
+ if m.locked[tp][source] then
+ return nil
+ end
+ m.locked[tp][source] = true
+ return function ()
+ m.locked[tp][source] = nil
+ end
+end
+
+--- 获取link的uri
+function m.getLinkUris(call)
+ local workspace = require 'workspace'
+ local func = call.node
+ local name = func.special
+ if name == 'require' then
+ local args = call.args
+ if not args[1] then
+ return nil
+ end
+ local literal = guide.getLiteral(args[1])
+ if type(literal) ~= 'string' then
+ return nil
+ end
+ return workspace.findUrisByRequirePath(literal, true)
+ end
+end
+
+m.cacheTracker = setmetatable({}, { __mode = 'kv' })
+
+--- 刷新缓存
+function m.refreshCache()
+ if m.cache then
+ m.cache.dead = true
+ end
+ m.cache = {
+ eachRef = {},
+ eachField = {},
+ getGlobals = {},
+ getLinks = {},
+ getGlobal = {},
+ specialName = {},
+ getLibrary = {},
+ getValue = {},
+ specials = nil,
+ }
+ m.locked = {
+ eachRef = {},
+ eachField = {},
+ getGlobals = {},
+ getLinks = {},
+ getLibrary = {},
+ getValue = {},
+ }
+ m.cacheTracker[m.cache] = true
+end
+
+return m