From 91af49cdfd292b7e479e6e0327d04c436256ed7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=80=E8=90=8C=E5=B0=8F=E6=B1=90?= Date: Mon, 11 Nov 2019 19:31:04 +0800 Subject: =?UTF-8?q?=E8=BF=87=E5=8D=95=E6=96=87=E4=BB=B6=20reference=20?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server-beta/src/core/reference.lua | 25 +++++++++++++++++++++--- server-beta/src/files.lua | 2 +- server-beta/src/parser/ast.lua | 2 +- server-beta/src/searcher/eachRef.lua | 37 ++++++++++++++++++++++++++++++++++++ server-beta/test/references/init.lua | 19 +++++++++++++----- 5 files changed, 75 insertions(+), 10 deletions(-) diff --git a/server-beta/src/core/reference.lua b/server-beta/src/core/reference.lua index 092b6218..4c47a072 100644 --- a/server-beta/src/core/reference.lua +++ b/server-beta/src/core/reference.lua @@ -3,7 +3,15 @@ local workspace = require 'workspace' local files = require 'files' local searcher = require 'searcher' -local function findDef(source, callback) +local function isFunction(source, offset) + if source.type ~= 'function' then + return false + end + -- 必须点在 `function` 这个单词上才能查找函数引用 + return offset >= source.start and offset < source.start + #'function' +end + +local function findDef(source, offset, callback) if source.type ~= 'local' and source.type ~= 'getlocal' and source.type ~= 'setlocal' @@ -15,7 +23,8 @@ local function findDef(source, callback) and source.type ~= 'string' and source.type ~= 'number' and source.type ~= 'boolean' - and source.type ~= 'goto' then + and source.type ~= 'goto' + and not isFunction(source, offset) then return end searcher.eachRef(source, function (info) @@ -41,6 +50,16 @@ local function findDef(source, callback) callback(src, uri) end end + if info.mode == 'value' then + local src = info.source + local root = guide.getRoot(src) + local uri = root.uri + if src.type == 'function' then + if src.parent.type == 'return' then + callback(src, uri) + end + end + end end) end @@ -51,7 +70,7 @@ return function (uri, offset) end local results = {} guide.eachSourceContain(ast.ast, offset, function (source) - findDef(source, function (target, uri) + findDef(source, offset, function (target, uri) results[#results+1] = { target = target, uri = files.getOriginUri(uri), diff --git a/server-beta/src/files.lua b/server-beta/src/files.lua index ad2d9552..0f3d721f 100644 --- a/server-beta/src/files.lua +++ b/server-beta/src/files.lua @@ -223,9 +223,9 @@ function m.findGlobals(name) local uris = {} for uri, file in pairs(m.fileMap) do if not file.globals then + file.globals = {} local ast = m.getAst(uri) if ast then - file.globals = {} local globals = searcher.getGlobals(ast.ast) for name in pairs(globals) do file.globals[name] = true diff --git a/server-beta/src/parser/ast.lua b/server-beta/src/parser/ast.lua index a54ef937..a87d9acc 100644 --- a/server-beta/src/parser/ast.lua +++ b/server-beta/src/parser/ast.lua @@ -706,7 +706,7 @@ local Defs = { type = 'binary', op = op, start = left.start, - finish = right.finish, + finish = right and right.finish or op.finish, [1] = left, [2] = right, } diff --git a/server-beta/src/searcher/eachRef.lua b/server-beta/src/searcher/eachRef.lua index 61aab081..e5877ae7 100644 --- a/server-beta/src/searcher/eachRef.lua +++ b/server-beta/src/searcher/eachRef.lua @@ -23,6 +23,35 @@ local function ofCall(func, index, callback) end) end +local function ofReturn(rtn, index, callback) + local func = guide.getParentFunction(rtn) + if not func then + return + end + -- 搜索函数调用的第 index 个接收值 + searcher.eachRef(func, function (info) + local source = info.source + local call = source.parent + if not call or call.type ~= 'call' then + return + end + local slc = call.parent + if slc.index == index then + searcher.eachRef(slc.parent, callback) + return + end + if call.extParent then + for i = 1, #call.extParent do + slc = call.extParent[i] + if slc.index == index then + searcher.eachRef(slc.parent, callback) + return + end + end + end + end) +end + local function ofSpecialCall(call, func, index, callback) local name = searcher.getSpecialName(func) if name == 'setmetatable' then @@ -100,6 +129,14 @@ local function ofValue(value, callback) searcher.eachRef(parent, callback) end end + if parent.type == 'return' then + for i = 1, #parent do + if parent[i] == value then + ofReturn(parent, i, callback) + break + end + end + end end local function ofSelf(loc, callback) diff --git a/server-beta/test/references/init.lua b/server-beta/test/references/init.lua index e66449fb..e009ee19 100644 --- a/server-beta/test/references/init.lua +++ b/server-beta/test/references/init.lua @@ -34,10 +34,10 @@ end function TEST(script) files.removeAll() local target = catch_target(script) - local start = script:find('', 1, true) + local start = script:find('<[?~]') + local finish = script:find('[?~]>') local pos = (start + finish) // 2 + 1 - local new_script = script:gsub('<[!?]', ' '):gsub('[!?]>', ' ') + local new_script = script:gsub('<[!?~]', ' '):gsub('[!?~]>', ' ') files.setText('', new_script) local results = core('', pos) @@ -92,13 +92,22 @@ end TEST [[ local function f() - return + return <~ () + end!> end local = f() ]] +TEST [[ +local function f() + return nil, <~ () + end!> +end + +local _, = f() +]] + TEST [[ table.() function table.() -- cgit v1.2.3