summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2019-12-16 21:31:58 +0800
committer最萌小汐 <sumneko@hotmail.com>2019-12-16 21:31:58 +0800
commit70ae6b13b620148e954dbdf9b5563fe04d4e52b2 (patch)
tree08c13b758b2e8aaee3a32f7a61fc7ca7103f5338
parentd65bc946566756b7627898118e467ae8bd234341 (diff)
downloadlua-language-server-70ae6b13b620148e954dbdf9b5563fe04d4e52b2.zip
整理代码
-rw-r--r--script-beta/core/rename.lua4
-rw-r--r--script-beta/vm/eachDef.lua147
-rw-r--r--script-beta/vm/eachRef.lua87
-rw-r--r--script-beta/vm/init.lua1
-rw-r--r--script-beta/vm/refOf.lua519
-rw-r--r--script-beta/vm/vm.lua1
6 files changed, 627 insertions, 132 deletions
diff --git a/script-beta/core/rename.lua b/script-beta/core/rename.lua
index 581fd19a..4beb3bbb 100644
--- a/script-beta/core/rename.lua
+++ b/script-beta/core/rename.lua
@@ -11,8 +11,8 @@ local function askForcing(str)
if TEST then
return true
end
- if Forcing == false then
- return false
+ if Forcing ~= nil then
+ return Forcing
end
local version = files.globalVersion
-- TODO
diff --git a/script-beta/vm/eachDef.lua b/script-beta/vm/eachDef.lua
index cf865384..66420d5b 100644
--- a/script-beta/vm/eachDef.lua
+++ b/script-beta/vm/eachDef.lua
@@ -1,102 +1,73 @@
-local vm = require 'vm.vm'
-local guide = require 'parser.guide'
-local files = require 'files'
+local guide = require 'parser.guide'
+local files = require 'files'
+local vm = require 'vm.vm'
+local library = require 'library'
+local await = require 'await'
-local function checkPath(source, info)
- if source.type == 'goto' then
- return true
- end
- local src = info.source
- local mode = guide.getPath(source, src)
- if not mode then
- return true
- end
- if mode == 'before' then
- return false
- end
- if mode == 'equal' then
- if src.type == 'field'
- or src.type == 'method'
- or src.type == 'local'
- or src.type == 'setglobal' then
- return true
- else
- return false
- end
+local function ofLocal(declare, source, callback)
+
+end
+
+local function eachDef(source, callback)
+ local stype = source.type
+ if stype == 'local' then
+ ofLocal(source, source, callback)
+ elseif stype == 'getlocal'
+ or stype == 'setlocal' then
+ ofLocal(source.node, source, callback)
end
- return true
end
--- TODO
--- 只搜索本文件中的引用
--- 跨文件时,选确定入口(main的return),然后递归搜索本文件中的引用
--- 如果类型为setfield等,要确定tbl相同
-function vm.eachDef(source, callback)
- local results = {}
- local returns = {}
- local infoMap = {}
- local sourceUri = guide.getRoot(source).uri
- vm.eachRef(source, function (info)
- if info.mode == 'declare'
- or info.mode == 'set' then
- results[#results+1] = info
- end
- if info.mode == 'return' then
- results[#results+1] = info
- local root = guide.getParentBlock(info.source)
- if root.type == 'main' then
- returns[root.uri] = info
+--- 获取所有的引用
+function vm.eachDef(source, callback, max)
+ local cache = vm.cache.eachDef[source]
+ if cache then
+ await.delay(function ()
+ return files.globalVersion
+ end)
+ if max then
+ if max > #cache then
+ max = #cache
end
+ else
+ max = #cache
end
- infoMap[info.source] = info
- end)
-
- local function pushDef(info)
- local res = callback(info)
- if res ~= nil then
- return res
- end
- local value = info.source.value
- local vinfo = infoMap[value]
- if vinfo then
- res = callback(vinfo)
+ for i = 1, max do
+ local res = callback(cache[i])
+ if res ~= nil then
+ return res
+ end
end
- return res
+ return
end
-
- local res
- local used = {}
- for _, info in ipairs(results) do
+ local unlock = vm.lock('eachDef', source)
+ if not unlock then
+ return
+ end
+ cache = {}
+ vm.cache.eachDef[source] = cache
+ local mark = {}
+ eachDef(source, function (info)
local src = info.source
- local destUri
- if used[src] then
- goto CONTINUE
+ if mark[src] then
+ return
end
- used[src] = true
- destUri = guide.getRoot(src).uri
- -- 如果是同一个文件,则检查位置关系后放行
- if sourceUri == destUri then
- if checkPath(source, info) then
- res = pushDef(info)
- end
- goto CONTINUE
- end
- -- 如果是global或field,则直接放行(因为无法确定顺序)
- if src.type == 'setindex'
- or src.type == 'setfield'
- or src.type == 'setmethod'
- or src.type == 'tablefield'
- or src.type == 'tableindex'
- or src.type == 'setglobal' then
- res = pushDef(info)
- goto CONTINUE
- end
- -- 如果不是同一个文件,则必须在该文件 return 后才放行
- if returns[destUri] then
- res = pushDef(info)
- goto CONTINUE
+ mark[src] = true
+ cache[#cache+1] = info
+ end)
+ unlock()
+ await.delay(function ()
+ return files.globalVersion
+ end)
+ if max then
+ if max > #cache then
+ max = #cache
end
- ::CONTINUE::
+ else
+ max = #cache
+ end
+ for i = 1, max do
+ local res = callback(cache[i])
if res ~= nil then
return res
end
diff --git a/script-beta/vm/eachRef.lua b/script-beta/vm/eachRef.lua
index 81dedcbe..c629f667 100644
--- a/script-beta/vm/eachRef.lua
+++ b/script-beta/vm/eachRef.lua
@@ -524,48 +524,7 @@ function vm.isSameRef(a, b)
end
end
---- 获取所有的引用
-function vm.eachRef(source, callback, max)
- local cache = vm.cache.eachRef[source]
- if cache then
- await.delay(function ()
- return files.globalVersion
- end)
- if max then
- if max > #cache then
- max = #cache
- end
- else
- max = #cache
- end
- for i = 1, max do
- local res = callback(cache[i])
- if res ~= nil then
- return res
- end
- end
- return
- end
- local unlock = vm.lock('eachRef', source)
- if not unlock then
- return
- end
- cache = {}
- vm.cache.eachRef[source] = cache
- local mark = {}
- eachRef(source, function (info)
- local src = info.source
- if mark[src] then
- return
- end
- mark[src] = true
- cache[#cache+1] = info
- end)
- unlock()
- for i = 1, #cache do
- local src = cache[i].source
- vm.cache.eachRef[src] = cache
- end
+local function applyCache(cache, callback, max)
await.delay(function ()
return files.globalVersion
end)
@@ -583,3 +542,47 @@ function vm.eachRef(source, callback, max)
end
end
end
+
+local function eachRef(source, callback)
+ local list = { source }
+ local mark = {}
+ local result = {}
+ local state = {}
+ local function found(info)
+ local src = info.source
+ if not mark[src] then
+ list[#list+1] = src
+ end
+ mark[src] = info
+ end
+ while #list > 0 do
+ local max = #list
+ local src = list[max]
+ list[max] = nil
+ vm.refOf(state, src, found)
+ end
+ for _, info in pairs(mark) do
+ result[#result+1] = info
+ end
+ return result
+end
+
+--- 获取所有的引用
+function vm.eachRef(source, callback, max)
+ local cache = vm.cache.eachRef[source]
+ if cache then
+ applyCache(cache, callback, max)
+ return
+ end
+ local unlock = vm.lock('eachRef', source)
+ if not unlock then
+ return
+ end
+ cache = eachRef(source, callback)
+ unlock()
+ for i = 1, #cache do
+ local src = cache[i].source
+ vm.cache.eachRef[src] = cache
+ end
+ applyCache(cache, callback, max)
+end
diff --git a/script-beta/vm/init.lua b/script-beta/vm/init.lua
index a4f81d07..3645e77a 100644
--- a/script-beta/vm/init.lua
+++ b/script-beta/vm/init.lua
@@ -1,5 +1,6 @@
local vm = require 'vm.vm'
require 'vm.eachField'
+require 'vm.refOf'
require 'vm.eachRef'
require 'vm.eachDef'
require 'vm.getGlobals'
diff --git a/script-beta/vm/refOf.lua b/script-beta/vm/refOf.lua
new file mode 100644
index 00000000..96daf646
--- /dev/null
+++ b/script-beta/vm/refOf.lua
@@ -0,0 +1,519 @@
+local vm = require 'vm.vm'
+local guide = require 'parser.guide'
+local files = require 'files'
+local library = require 'library'
+
+local function ofLocal(state, loc, callback)
+ if state[loc] then
+ return
+ end
+ state[loc] = true
+ -- 方法中的 self 使用了一个虚拟的定义位置
+ if loc.tag ~= 'self' then
+ callback {
+ source = loc,
+ mode = 'declare',
+ }
+ end
+ local refs = loc.ref
+ if refs then
+ for i = 1, #refs do
+ local ref = refs[i]
+ if ref.type == 'getlocal' then
+ callback {
+ source = ref,
+ mode = 'get',
+ }
+ if loc.tag == '_ENV' then
+ local parent = ref.parent
+ if parent.type == 'getfield'
+ or parent.type == 'getindex' then
+ if guide.getKeyName(parent) == '_G' then
+ callback {
+ source = parent,
+ mode = 'get',
+ }
+ end
+ end
+ end
+ elseif ref.type == 'setlocal' then
+ callback {
+ source = ref,
+ mode = 'set',
+ }
+ if loc.tag == '_ENV' then
+ if guide.getName(ref) == '_G' then
+ callback {
+ source = ref,
+ mode = 'get',
+ }
+ end
+ end
+ end
+ end
+ end
+end
+
+local function ofGlobal(state, source, callback)
+ if state[source] then
+ return
+ end
+ local key = guide.getKeyName(source)
+ local node = source.node
+ if node.tag == '_ENV' then
+ local uris = files.findGlobals(key)
+ for i = 1, #uris do
+ local uri = uris[i]
+ local ast = files.getAst(uri)
+ local globals = vm.getGlobals(ast.ast)
+ if globals and globals[key] then
+ for _, info in ipairs(globals[key]) do
+ state[info.source] = true
+ callback(info)
+ end
+ end
+ end
+ else
+ vm.eachField(node, function (info)
+ if key == info.key then
+ state[info.source] = true
+ callback {
+ source = info.source,
+ mode = info.mode,
+ }
+ end
+ end)
+ end
+end
+
+local function ofField(state, source, callback)
+ if state[source] then
+ return
+ end
+ local parent = source.parent
+ local key = guide.getKeyName(source)
+ if parent.type == 'tablefield'
+ or parent.type == 'tableindex' then
+ local tbl = parent.parent
+ vm.eachField(tbl, function (info)
+ if key == info.key then
+ state[info.source] = true
+ callback {
+ source = info.source,
+ mode = info.mode,
+ }
+ end
+ end)
+ else
+ local node = parent.node
+ vm.eachField(node, function (info)
+ if key == info.key then
+ state[info.source] = true
+ callback {
+ source = info.source,
+ mode = info.mode,
+ }
+ end
+ end)
+ end
+end
+
+local function ofLabel(state, source, callback)
+ if state[source] then
+ return
+ end
+ state[source] = true
+ callback {
+ source = source,
+ mode = 'set',
+ }
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ callback {
+ source = ref,
+ mode = 'get',
+ }
+ end
+ end
+end
+
+local function ofGoTo(state, source, callback)
+ local name = source[1]
+ local label = guide.getLabel(source, name)
+ if label then
+ ofLabel(state, label, callback)
+ end
+end
+
+local function ofValue(state, source, callback)
+ callback {
+ source = source,
+ mode = 'value',
+ }
+end
+
+local function ofIndex(state, source, callback)
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ if parent.type == 'setindex'
+ or parent.type == 'getindex'
+ or parent.type == 'tableindex' then
+ ofField(state, source, callback)
+ end
+end
+
+local function ofCall(state, func, index, callback, offset)
+ offset = offset or 0
+ vm.eachRef(func, function (info)
+ local src = info.source
+ local returns
+ if src.type == 'main' or src.type == 'function' then
+ returns = src.returns
+ end
+ if returns then
+ -- 搜索函数第 index 个返回值
+ for i = 1, #returns do
+ local rtn = returns[i]
+ local val = rtn[index-offset]
+ if val then
+ callback {
+ source = val,
+ mode = 'return',
+ }
+ end
+ end
+ end
+ end)
+end
+
+local function ofSpecialCall(state, call, func, index, callback, offset)
+ local name = func.special
+ offset = offset or 0
+ if name == 'setmetatable' then
+ if index == 1 + offset then
+ local args = call.args
+ if args[1+offset] then
+ callback {
+ source = args[1+offset],
+ mode = 'get',
+ }
+ end
+ if args[2+offset] then
+ vm.eachField(args[2+offset], function (info)
+ if info.key == 's|__index' then
+ callback(info)
+ end
+ end)
+ end
+ vm.setMeta(args[1+offset], args[2+offset])
+ end
+ elseif name == 'require' then
+ if index == 1 + offset then
+ local result = vm.getLinkUris(call)
+ if result then
+ local myUri = guide.getRoot(call).uri
+ for i = 1, #result do
+ local uri = result[i]
+ if not files.eq(uri, myUri) then
+ local ast = files.getAst(uri)
+ if ast then
+ ofCall(state, ast.ast, 1, callback)
+ end
+ end
+ end
+ end
+
+ local args = call.args
+ if args[1+offset] then
+ if args[1+offset].type == 'string' then
+ local objName = args[1+offset][1]
+ local lib = library.library[objName]
+ if lib then
+ callback {
+ source = lib,
+ mode = 'value',
+ }
+ end
+ end
+ end
+ end
+ elseif name == 'pcall'
+ or name == 'xpcall' then
+ if index >= 2-offset then
+ local args = call.args
+ if args[1+offset] then
+ vm.eachRef(args[1+offset], function (info)
+ local src = info.source
+ if src.type == 'function' then
+ ofCall(state, src, index, callback, 1+offset)
+ ofSpecialCall(state, call, src, index, callback, 1+offset)
+ end
+ end)
+ end
+ end
+ end
+end
+
+local function ofSelect(state, source, callback)
+ -- 检查函数返回值
+ local call = source.vararg
+ if call.type == 'call' then
+ ofCall(state, call.node, source.index, callback)
+ ofSpecialCall(state, call, call.node, source.index, callback)
+ end
+end
+
+local function ofMain(state, source, callback)
+ callback {
+ source = source,
+ mode = 'main',
+ }
+end
+
+local function getCallRecvs(call)
+ local parent = call.parent
+ if parent.type ~= 'select' then
+ return nil
+ end
+ local extParent = call.extParent
+ local recvs = {}
+ recvs[1] = parent.parent
+ if extParent then
+ for i = 1, #extParent do
+ local p = extParent[i]
+ recvs[#recvs+1] = p.parent
+ end
+ end
+ return recvs
+end
+
+--- 自己作为函数的参数
+local function checkAsArg(state, source, callback)
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ if parent.type == 'callargs' then
+ local call = parent.parent
+ local func = call.node
+ local name = func.special
+ if name == 'setmetatable' then
+ if parent[1] == source then
+ if parent[2] then
+ vm.eachField(parent[2], function (info)
+ if info.key == 's|__index' then
+ callback(info)
+ end
+ end)
+ end
+ local recvs = getCallRecvs(call)
+ if recvs and recvs[1] then
+ callback {
+ source = recvs[1],
+ mode = 'return',
+ }
+ end
+ vm.setMeta(source, parent[2])
+ end
+ end
+ end
+end
+
+local function ofCallSelect(state, call, index, callback)
+ local slc = call.parent
+ if slc.index == index then
+ callback {
+ source = slc.parent,
+ mode = 'get',
+ }
+ return
+ end
+ if call.extParent then
+ for i = 1, #call.extParent do
+ slc = call.extParent[i]
+ if slc.index == index then
+ callback {
+ source = slc.parent,
+ mode = 'get',
+ }
+ return
+ end
+ end
+ end
+end
+
+--- 自己作为函数的返回值
+local function checkAsReturn(state, source, callback)
+ local parent = source.parent
+ if source.type == 'field'
+ or source.type == 'method' then
+ parent = parent.parent
+ end
+ if not parent or parent.type ~= 'return' then
+ return
+ end
+ local func = guide.getParentFunction(source)
+ if func.type == 'main' then
+ local myUri = func.uri
+ local uris = files.findLinkTo(myUri)
+ if not uris then
+ return
+ end
+ for i = 1, #uris do
+ local uri = uris[i]
+ local ast = files.getAst(uri)
+ if ast then
+ local links = vm.getLinks(ast.ast)
+ if links then
+ for linkUri, calls in pairs(links) do
+ if files.eq(linkUri, myUri) then
+ for j = 1, #calls do
+ ofCallSelect(state, calls[j], 1, callback)
+ end
+ end
+ end
+ end
+ end
+ end
+ else
+ local index
+ for i = 1, #parent do
+ if parent[i] == source then
+ index = i
+ break
+ end
+ end
+ if not index then
+ return
+ end
+ vm.eachRef(func, function (info)
+ local src = info.source
+ local call = src.parent
+ if not call or call.type ~= 'call' then
+ return
+ end
+ local recvs = getCallRecvs(call)
+ if recvs and recvs[index] then
+ callback {
+ source = recvs[index],
+ mode = 'return',
+ }
+ elseif index == 1 then
+ callback {
+ type = 'call',
+ source = call,
+ }
+ end
+ end)
+ end
+end
+
+local function checkAsParen(state, source, callback)
+ if state[source] then
+ return
+ end
+ state[source] = true
+ if source.parent and source.parent.type == 'paren' then
+ vm.refOf(state, source.parent, callback)
+ end
+end
+
+local function checkValue(state, source, callback)
+ if source.value then
+ callback {
+ source = source.value,
+ mode = 'value',
+ }
+ end
+end
+
+local function checkSetValue(value, callback)
+ if value.type == 'field'
+ or value.type == 'method' then
+ value = value.parent
+ end
+ local parent = value.parent
+ if not parent then
+ return
+ end
+ if parent.type == 'local'
+ or parent.type == 'setglobal'
+ or parent.type == 'setlocal'
+ or parent.type == 'setfield'
+ or parent.type == 'setmethod'
+ or parent.type == 'setindex'
+ or parent.type == 'tablefield'
+ or parent.type == 'tableindex' then
+ if parent.value == value then
+ callback {
+ source = parent,
+ mode = 'set',
+ }
+ if guide.getName(parent) == '__index' then
+ if parent.type == 'tablefield'
+ or parent.type == 'tableindex' then
+ local t = parent.parent
+ local args = t.parent
+ if args[2] == t then
+ local call = args.parent
+ local func = call.node
+ if func.special == 'setmetatable' then
+ callback {
+ source = args[1],
+ mode = 'get',
+ }
+ end
+ end
+ end
+ end
+ end
+ end
+end
+
+function vm.refOf(state, source, callback)
+ local stype = source.type
+ if stype == 'local' then
+ ofLocal(state, source, callback)
+ elseif stype == 'getlocal'
+ or stype == 'setlocal' then
+ ofLocal(state, source.node, callback)
+ elseif stype == 'setglobal'
+ or stype == 'getglobal' then
+ ofGlobal(state, source, callback)
+ elseif stype == 'field'
+ or stype == 'method' then
+ ofField(state, source, callback)
+ elseif stype == 'setfield'
+ or stype == 'getfield'
+ or stype == 'tablefield' then
+ ofField(state, source.field, callback)
+ elseif stype == 'setmethod'
+ or stype == 'getmethod' then
+ ofField(state, source.method, callback)
+ elseif stype == 'goto' then
+ ofGoTo(state, source, callback)
+ elseif stype == 'label' then
+ ofLabel(state, source, callback)
+ elseif stype == 'number'
+ or stype == 'boolean'
+ or stype == 'string' then
+ ofIndex(state, source, callback)
+ ofValue(state, source, callback)
+ elseif stype == 'table'
+ or stype == 'function' then
+ ofValue(state, source, callback)
+ elseif stype == 'select' then
+ ofSelect(state, source, callback)
+ elseif stype == 'main' then
+ ofMain(state, source, callback)
+ elseif stype == 'paren' then
+ vm.refOf(state, source.exp, callback)
+ end
+ checkValue(state, source, callback)
+ checkAsArg(state, source, callback)
+ checkAsReturn(state, source, callback)
+ checkAsParen(state, source, callback)
+ checkSetValue(state, source, callback)
+end
diff --git a/script-beta/vm/vm.lua b/script-beta/vm/vm.lua
index 5460c52b..06ffc172 100644
--- a/script-beta/vm/vm.lua
+++ b/script-beta/vm/vm.lua
@@ -68,6 +68,7 @@ function m.refreshCache()
end
m.cache = {
eachRef = {},
+ eachDef = {},
eachField = {},
eachMeta = {},
getGlobals = {},