From 3f9ca58f22eeccd30ff1093c9874bfa79d40a29f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=80=E8=90=8C=E5=B0=8F=E6=B1=90?= Date: Mon, 24 Dec 2018 14:08:22 +0800 Subject: =?UTF-8?q?=E6=8D=A2=E4=B8=AArequire=E7=9A=84=E5=81=9A=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/src/matcher/vm.lua | 204 +++++++++++++---------------------- server/src/service.lua | 56 ++++------ server/test/crossfile/definition.lua | 10 +- server/test/crossfile/hover.lua | 28 +++++ server/test/hover/init.lua | 18 ++++ 5 files changed, 146 insertions(+), 170 deletions(-) diff --git a/server/src/matcher/vm.lua b/server/src/matcher/vm.lua index de55c12e..74a86d59 100644 --- a/server/src/matcher/vm.lua +++ b/server/src/matcher/vm.lua @@ -455,6 +455,77 @@ function mt:callSetMetaTable(func, values) self:checkMetaIndex(values[1], values[2]) end +function mt:getRequire(strValue, destVM) + -- 取出对方的主函数 + local main = destVM.results.main + -- 获取主函数返回值,注意不能修改对方的环境 + local mainValue + if main.returns then + mainValue = main.returns[1] + else + mainValue = self:createValue('boolean', { + type = 'name', + start = 0, + finish = 0, + [1] = '', + uri = destVM.uri, + }, true) + end + + -- 支持 require 'xxx' 的转到定义 + local strSource = strValue.source + self.results.sources[strSource] = strValue + strValue.uri = destVM.uri + + return mainValue +end + +function mt:getLoadFile(strValue, destVM) + -- 取出对方的主函数 + local main = destVM.results.main + -- loadfile 的返回值就是对方的主函数 + local mainValue = main + + -- 支持 loadfile 'xxx.lua' 的转到定义 + local strSource = strValue.source + self.results.sources[strSource] = strValue + strValue.uri = destVM.uri + + return mainValue +end + +function mt:tryRequireOne(strValue, mode) + if not self.lsp or not self.lsp.workspace then + return nil + end + local str = strValue.value + if type(str) == 'string' then + local uri + if mode == 'require' then + uri = self.lsp.workspace:searchPath(self.uri, str) + elseif mode == 'loadfile' then + uri = self.lsp.workspace:loadPath(self.uri, str) + elseif mode == 'dofile' then + uri = self.lsp.workspace:loadPath(self.uri, str) + end + -- 如果取不到VM(不编译),则做个标记,之后再取一次 + local destVM = self.lsp:getVM(uri) + if destVM then + if mode == 'require' then + return self:getRequire(strValue, destVM) + elseif mode == 'loadfile' then + return self:getLoadFile(strValue, destVM) + elseif mode == 'dofile' then + return self:getRequire(strValue, destVM) + end + else + self.lsp:needCompile(uri) + self.lsp:needCompile(self.uri) + end + end + return nil +end + function mt:callRequire(func, values) if not values[1] then values[1] = self:createValue('any') @@ -469,12 +540,8 @@ function mt:callRequire(func, values) self:setFunctionReturn(func, 1, value) return else - local requireValue = self:createValue('boolean', nil, true) + local requireValue = self:tryRequireOne(values[1], 'require') or self:createValue('boolean') self:setFunctionReturn(func, 1, requireValue) - self.requires[requireValue] = { - mode = 'require', - str = values[1], - } end end @@ -486,12 +553,8 @@ function mt:callLoadFile(func, values) if type(str) ~= 'string' then return end - local requireValue = self:buildFunction() + local requireValue = self:tryRequireOne(values[1], 'loadfile') self:setFunctionReturn(func, 1, requireValue) - self.requires[requireValue] = { - mode = 'loadfile', - str = values[1], - } end function mt:callDoFile(func, values) @@ -502,12 +565,8 @@ function mt:callDoFile(func, values) if type(str) ~= 'string' then return end - local requireValue = self:createValue('any') + local requireValue = self:tryRequireOne(values[1], 'dofile') self:setFunctionReturn(func, 1, requireValue) - self.requires[requireValue] = { - mode = 'dofile', - str = values[1], - } end function mt:call(func, values) @@ -1242,117 +1301,6 @@ function mt:createEnvironment() gValue.child = envValue.child end -function mt:mergeRequire(value, strValue, destVM) - -- 取出对方的主函数 - local main = destVM.results.main - -- 获取主函数返回值,注意不能修改对方的环境 - local mainValue - if not main.returns then - mainValue = self:createValue('nil', { - type = 'name', - start = 0, - finish = 0, - [1] = '', - uri = destVM.uri, - }) - else - mainValue = main.returns[1] - end - self:mergeValue(value, mainValue) - - -- 支持 require 'xxx' 的转到定义 - local strSource = strValue.source - self.results.sources[strSource] = strValue - strValue.uri = destVM.uri -end - -function mt:mergeLoadFile(value, strValue, destVM) - -- 取出对方的主函数 - local main = destVM.results.main - -- loadfile 的返回值就是对方的主函数 - local mainValue = main - self:mergeValue(value, mainValue) - - -- 支持 loadfile 'xxx.lua' 的转到定义 - local strSource = strValue.source - self.results.sources[strSource] = strValue - strValue.uri = destVM.uri -end - -function mt:loadRequires() - if not self.lsp or not self.lsp.workspace then - return - end - local copy = {} - for k, v in pairs(self.requires) do - self.requires[k] = nil - copy[k] = v - end - for value, data in pairs(copy) do - local strValue = data.str - local mode = data.mode - local str = strValue.value - if type(str) == 'string' then - local uri - if mode == 'require' then - uri = self.lsp.workspace:searchPath(self.uri, str) - elseif mode == 'loadfile' then - uri = self.lsp.workspace:loadPath(self.uri, str) - elseif mode == 'dofile' then - uri = self.lsp.workspace:loadPath(self.uri, str) - elseif mode == '' then - end - -- 如果循环require,这里会返回nil - -- 会当场编译VM - local destVM = self.lsp:loadVM(uri) - if destVM then - if mode == 'require' then - self:mergeRequire(value, strValue, destVM) - elseif mode == 'loadfile' then - self:mergeLoadFile(value, strValue, destVM) - elseif mode == 'dofile' then - self:mergeRequire(value, strValue, destVM) - end - end - end - end -end - -function mt:tryLoadRequires() - if not self.lsp or not self.lsp.workspace then - return - end - for value, data in pairs(self.requires) do - local strValue = data.str - local mode = data.mode - local str = strValue.value - if type(str) == 'string' then - local uri - if mode == 'require' then - uri = self.lsp.workspace:searchPath(self.uri, str) - elseif mode == 'loadfile' then - uri = self.lsp.workspace:loadPath(self.uri, str) - elseif mode == 'dofile' then - uri = self.lsp.workspace:loadPath(self.uri, str) - end - -- 如果取不到VM(不编译),则做个标记,之后再取一次 - local destVM = self.lsp:getVM(uri) - if destVM then - if mode == 'require' then - self:mergeRequire(value, strValue, destVM) - elseif mode == 'loadfile' then - self:mergeLoadFile(value, strValue, destVM) - elseif mode == 'dofile' then - self:mergeRequire(value, strValue, destVM) - end - self.requires[value] = nil - else - self.lsp:needRequires(self.uri) - end - end - end -end - local function compile(ast, lsp, uri) local vm = setmetatable({ scope = env { @@ -1372,7 +1320,6 @@ local function compile(ast, lsp, uri) }, libraryValue = {}, libraryChild = {}, - requires = {}, lsp = lsp, uri = uri, }, mt) @@ -1383,9 +1330,6 @@ local function compile(ast, lsp, uri) -- 执行代码 vm:doActions(ast) - -- 合并 - vm:tryLoadRequires() - return vm end diff --git a/server/src/service.lua b/server/src/service.lua index 12a6ed25..0c167570 100644 --- a/server/src/service.lua +++ b/server/src/service.lua @@ -131,13 +131,13 @@ function mt:clearDiagnostics(uri) }) end -function mt:_buildTextCache() +function mt:compileAll() if not next(self._needCompile) then return end local list = {} - for uri in pairs(self._needCompile) do - list[#list+1] = uri + for i, uri in ipairs(self._needCompile) do + list[i] = uri end local size = 0 @@ -167,6 +167,14 @@ function mt:read(mode) return self._input(mode) end +function mt:needCompile(uri) + if self._needCompile[uri] then + return + end + self._needCompile[uri] = true + self._needCompile[#self._needCompile+1] = uri +end + function mt:saveText(uri, version, text) local obj = self._file[uri] if obj then @@ -175,13 +183,13 @@ function mt:saveText(uri, version, text) end obj.version = version obj.text = text - self._needCompile[uri] = true + self:needCompile(uri) else self._file[uri] = { version = version, text = text, } - self._needCompile[uri] = true + self:needCompile(uri) end end @@ -198,7 +206,7 @@ function mt:readText(uri, path) version = -1, text = text, } - self._needCompile[uri] = true + self:needCompile(uri) end function mt:open(uri) @@ -215,7 +223,7 @@ end function mt:reCompile() for uri in pairs(self._opening) do - self._needCompile[uri] = true + self:needCompile(uri) end end @@ -237,6 +245,11 @@ function mt:compileVM(uri) return nil end self._needCompile[uri] = nil + for i, u in ipairs(self._needCompile) do + if u == uri then + table.remove(self._needCompile, i) + end + end local ast = parser:ast(obj.text) obj.vm = matcher.vm(ast, self, uri) @@ -261,31 +274,6 @@ function mt:getVM(uri) return obj.vm end -function mt:needRequires(uri) - self._needRequire[uri] = true -end - -function mt:_loadRequires() - if not self.workspace then - return - end - if not next(self._needRequire) then - return - end - local copy = {} - for uri in pairs(self._needRequire) do - self._needRequire[uri] = nil - copy[uri] = true - end - for uri in pairs(copy) do - local obj = self._file[uri] - if obj then - obj.vm:loadRequires() - self._needDiagnostics[uri] = true - end - end -end - function mt:removeText(uri) self._file[uri] = nil end @@ -302,9 +290,8 @@ function mt:onTick() rpc:recieve(proto) end end - self:_buildTextCache() + self:compileAll() self:_doDiagnostic() - self:_loadRequires() if os.clock() - self._clock >= 600 then self._clock = os.clock() @@ -344,7 +331,6 @@ return function () local session = setmetatable({ _file = {}, _needCompile = {}, - _needRequire = {}, _needDiagnostics = {}, _opening = {}, _clock = -100, diff --git a/server/test/crossfile/definition.lua b/server/test/crossfile/definition.lua index 43261771..11d5f746 100644 --- a/server/test/crossfile/definition.lua +++ b/server/test/crossfile/definition.lua @@ -38,14 +38,14 @@ function TEST(data) local sourceScript, sourceList = catch_target(data[2].content, '?') local sourceUri = ws:uriEncode(fs.path(data[2].path)) - lsp:saveText(targetUri, 1, targetScript) lsp:saveText(sourceUri, 1, sourceScript) - ws:addFile(targetUri) ws:addFile(sourceUri) - lsp:compileVM(targetUri) - lsp:compileVM(sourceUri) + lsp:saveText(targetUri, 1, targetScript) + ws:addFile(targetUri) + lsp:compileAll() + lsp:compileAll() - local sourceVM = lsp:loadVM(sourceUri) + local sourceVM = lsp:getVM(sourceUri) assert(sourceVM) local sourcePos = (sourceList[1][1] + sourceList[1][2]) // 2 local positions = matcher.definition(sourceVM, sourcePos) diff --git a/server/test/crossfile/hover.lua b/server/test/crossfile/hover.lua index 8bcc9411..51ac85bd 100644 --- a/server/test/crossfile/hover.lua +++ b/server/test/crossfile/hover.lua @@ -135,3 +135,31 @@ TEST { label = 'function (a: any, b: any)', } } + +TEST { + { + path = 'a.lua', + content = [[ + local mt = {} + mt.__index = mt + + function mt:add(a, b) + end + + return function () + return setmetatable({}, mt) + end + ]], + }, + { + path = 'b.lua', + content = [[ + local m = require 'a' + local obj = m() + obj:() + ]] + }, + hover = { + label = 'function mt:add(a: any, b: any)' + }, +} diff --git a/server/test/hover/init.lua b/server/test/hover/init.lua index 5d0cb068..445e72ba 100644 --- a/server/test/hover/init.lua +++ b/server/test/hover/init.lua @@ -198,3 +198,21 @@ end function x() -> any ]] + +TEST [[ +local mt = {} +mt.__index = mt + +function mt:add(a, b) +end + +local function init() + return setmetatable({}, mt) +end + +local t = init() +t:() +]] +[[ +function mt:add(a: any, b: any) +]] -- cgit v1.2.3