diff options
Diffstat (limited to 'server/src')
-rw-r--r-- | server/src/matcher/definition.lua | 32 | ||||
-rw-r--r-- | server/src/matcher/vm.lua | 117 | ||||
-rw-r--r-- | server/src/method/textDocument/definition.lua | 6 | ||||
-rw-r--r-- | server/src/service.lua | 6 |
4 files changed, 88 insertions, 73 deletions
diff --git a/server/src/matcher/definition.lua b/server/src/matcher/definition.lua index 3dc4f8ff..57e4b237 100644 --- a/server/src/matcher/definition.lua +++ b/server/src/matcher/definition.lua @@ -1,25 +1,5 @@ local findResult = require 'matcher.find_result' -local function tryMeta(var) - local keys = {} - repeat - if var.childs.meta then - local metavar = var.childs.meta - for i = #keys, 1, -1 do - local key = keys[i] - metavar = metavar.childs[key] - if not metavar then - return nil - end - end - return metavar - end - keys[#keys+1] = var.key - var = var.parent - until not var - return nil -end - local function parseResult(result) local positions = {} local tp = result.type @@ -35,14 +15,6 @@ local function parseResult(result) positions[#positions+1] = {info.source.start, info.source.finish} end end - --local metavar = tryMeta(var) - --if metavar then - -- for _, info in ipairs(metavar) do - -- if info.type == 'set' then - -- positions[#positions+1] = {info.source.start, --info.source.finish} - -- end - -- end - --end elseif tp == 'label' then for _, info in ipairs(result.object) do if info.type == 'set' then @@ -55,8 +27,8 @@ local function parseResult(result) return positions end -return function (results, pos) - local result = findResult(results, pos) +return function (vm, pos) + local result = findResult(vm.results, pos) if not result then return nil end diff --git a/server/src/matcher/vm.lua b/server/src/matcher/vm.lua index 47c6b44a..7f99d087 100644 --- a/server/src/matcher/vm.lua +++ b/server/src/matcher/vm.lua @@ -20,7 +20,7 @@ end function mt:addInfo(obj, type, source) if source and not source.start then - error('Miss start') + error('Miss start: ' .. table.dump(source)) end obj[#obj+1] = { type = type, @@ -54,37 +54,65 @@ function mt:createTable(source) if key.index then local index = self:getIndex(key) local field = self:createField(tbl, index, key) - self:setValue(field, value) + if value.type == 'list' then + self:setValue(field, value[1]) + else + self:setValue(field, value) + end self:addInfo(field, 'set', key) else if key.type == 'name' then local index = key[1] local field = self:createField(tbl, index, key) - self:setValue(field, value) + if value.type == 'list' then + self:setValue(field, value[1]) + else + self:setValue(field, value) + end self:addInfo(field, 'set', key) end end else local value = self:getExp(obj) - n = n + 1 - local field = self:createField(tbl, n) - self:setValue(field, value) + if value.type == 'list' then + for i, v in ipairs(value) do + local field = self:createField(tbl, n + i) + self:setValue(field, v) + end + break + else + n = n + 1 + local field = self:createField(tbl, n) + self:setValue(field, value) + end end end return tbl end function mt:coverValue(target, source) + local child = target.child for k in pairs(target) do target[k] = nil end for k, v in pairs(source) do target[k] = v end + if child then + if not target.child then + target.child = {} + end + for k, v in pairs(child) do + if target.child[k] == nil then + target.child[k] = v + end + end + end end function mt:setValue(var, value) assert(not value or value.type ~= 'list') + value = value or self:createNil(var) if var.value then if var.value.type == 'nil' then -- 允许覆盖nil @@ -93,7 +121,7 @@ function mt:setValue(var, value) end end else - var.value = value or self:createNil(var) + var.value = value end return value end @@ -106,6 +134,7 @@ function mt:getValue(var) end function mt:createField(pValue, name, source) + assert(pValue.type ~= 'local' and pValue.type ~= 'field') local field = { type = 'field', key = name, @@ -192,41 +221,50 @@ function mt:forList(list, callback) end end -function mt:call(func, values) - if func.used then - return self:getFunctionReturns(func) - end - func.used = true - +function mt:setFunctionArgs(func, values) if not func.args then - return self:getFunctionReturns(func) + return end for i, var in ipairs(func.args) do - if var then - if var.type == 'dots' then - local list = { - type = 'list', - } - if values then - for n = i, #values do - list[n-i+1] = values[n] - end - self:setValue(var, list) - else - self:setValue(var, nil) + if var.type == 'dots' then + local list = { + type = 'list', + } + if values then + for n = i, #values do + list[n-i+1] = values[n] end - break + self:setValue(var, list) else - if values then - self:setValue(var, values[i]) - else - self:setValue(var, nil) - end + self:setValue(var, nil) end + break + else + if values then + self:setValue(var, values[i]) + else + self:setValue(var, nil) + end + end + end +end + +function mt:callSetMetaTable(values) + values[1].metatable = values[2] + self:setFunctionReturn(self:getCurrentFunction(), 1, values[1]) +end + +function mt:call(func, values) + local lib = func.lib + if lib and lib.special then + if lib.special == 'setmetatable' then + self:callSetMetaTable(values) end end + self:setFunctionArgs(func, values) + return self:getFunctionReturns(func) end @@ -314,7 +352,7 @@ function mt:getLibValue(lib) for i, arg in ipairs(lib.args) do values[i] = self:getLibValue(arg) or self:createNil() end - self:call(value, values) + self:setFunctionArgs(value, values) end elseif tp == 'string' then value = self:createString(lib.value) @@ -334,7 +372,7 @@ end function mt:getName(name, source) local var = self.scope.locals[name] - or self:getField(self.scope.locals._ENV, name, source) + or self:getField(self:getValue(self.scope.locals._ENV), name, source) return var end @@ -408,7 +446,9 @@ function mt:getSimple(simple, mode) local index = self:getIndex(obj) field = self:getField(value, index, obj) if mode == 'value' or i < #simple then - self:addInfo(field, 'get', obj) + if obj.start then + self:addInfo(field, 'get', obj) + end end value = self:getValue(field) else @@ -561,7 +601,10 @@ function mt:doSet(action) elseif key.type == 'simple' then local field = self:getSimple(key, 'field') self:setValue(field, value) - self:addInfo(field, 'set', key[#key]) + local source = key[#key] + if source.start then + self:addInfo(field, 'set', source) + end end end) end @@ -775,7 +818,7 @@ local function compile(ast) -- 执行代码 vm:doActions(ast) - return vm.results + return vm end return function (ast) diff --git a/server/src/method/textDocument/definition.lua b/server/src/method/textDocument/definition.lua index d44ec1f6..a40c1832 100644 --- a/server/src/method/textDocument/definition.lua +++ b/server/src/method/textDocument/definition.lua @@ -2,13 +2,13 @@ local matcher = require 'matcher' return function (lsp, params) local uri = params.textDocument.uri - local results, lines = lsp:loadText(uri) - if not results then + local vm, lines = lsp:loadText(uri) + if not vm then return {} end -- lua是从1开始的,因此都要+1 local position = lines:position(params.position.line + 1, params.position.character + 1) - local positions = matcher.definition(results, position) + local positions = matcher.definition(vm, position) if not positions then return {} end diff --git a/server/src/service.lua b/server/src/service.lua index 109a8c1c..6e753b39 100644 --- a/server/src/service.lua +++ b/server/src/service.lua @@ -199,15 +199,15 @@ function mt:compileText(uri) end self._needCompile[uri] = nil local ast = parser:ast(obj.text) - obj.results = matcher.compile(ast) - if not obj.results then + obj.vm = matcher.vm(ast) + if not obj.vm then return obj end obj.lines = parser:lines(obj.text, 'utf8') self._needDiagnostics[uri] = { ast = ast, - results = obj.results, + vm = obj.vm, lines = obj.lines, uri = uri, } |