summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/matcher/definition.lua14
-rw-r--r--src/matcher/implementation.lua280
-rw-r--r--src/matcher/init.lua3
-rw-r--r--src/method/init.lua1
-rw-r--r--src/method/initialize.lua2
-rw-r--r--src/method/textDocument/implementation.lua46
-rw-r--r--test/definition/set.lua26
-rw-r--r--test/implementation/arg.lua23
-rw-r--r--test/implementation/bug.lua15
-rw-r--r--test/implementation/function.lua24
-rw-r--r--test/implementation/init.lua22
-rw-r--r--test/implementation/local.lua191
-rw-r--r--test/implementation/set.lua31
-rw-r--r--test/implementation/table.lua6
-rw-r--r--test/main.lua1
15 files changed, 680 insertions, 5 deletions
diff --git a/src/matcher/definition.lua b/src/matcher/definition.lua
index 04d01135..a744efed 100644
--- a/src/matcher/definition.lua
+++ b/src/matcher/definition.lua
@@ -32,8 +32,16 @@ end
local function scopeSet(obj)
local name = obj[1]
+ local scope = scopes[#scopes]
+ if not scope[name] then
+ scope[name] = obj
+ end
+end
+
+local function globalSet(obj)
+ local name = obj[1]
if not scopeGet(name) then
- local scope = scopes[#scopes]
+ local scope = scopes[1]
scope[name] = obj
end
end
@@ -97,7 +105,7 @@ function defs.Set(simples)
if simple.type == 'simple' and #simple == 1 then
local obj = simple[1]
local name = obj[1]
- scopeSet(obj)
+ globalSet(obj)
end
end
end
@@ -122,7 +130,7 @@ end
function defs.FunctionDef(simple, args)
if #simple == 1 then
- scopeSet(simple[1])
+ globalSet(simple[1])
end
scopePush()
-- 判断隐藏的局部变量self
diff --git a/src/matcher/implementation.lua b/src/matcher/implementation.lua
new file mode 100644
index 00000000..e245ef90
--- /dev/null
+++ b/src/matcher/implementation.lua
@@ -0,0 +1,280 @@
+local parser = require 'parser'
+
+local pos
+local defs = {}
+local scopes
+local result
+local namePos
+local colonPos
+
+local DUMMY_TABLE = {}
+
+local function scopeInit()
+ scopes = {{}}
+end
+
+local function scopeGet(name)
+ for i = #scopes, 1, -1 do
+ local scope = scopes[i]
+ local obj = scope[name]
+ if obj then
+ return obj
+ end
+ end
+ return nil
+end
+
+local function scopeSet(obj)
+ local name = obj[1]
+ local scope = scopes[#scopes]
+ scope[name] = obj
+end
+
+local function globalSet(obj)
+ local name = obj[1]
+ for i = #scopes, 1, -1 do
+ local scope = scopes[i]
+ local old = scope[name]
+ if old then
+ scope[name] = obj
+ return
+ end
+ end
+ local scope = scopes[1]
+ scope[name] = obj
+end
+
+local function scopePush()
+ scopes[#scopes+1] = {}
+end
+
+local function scopePop()
+ scopes[#scopes] = nil
+end
+
+local function checkImplementation(name, p)
+ if pos < p or pos > p + #name then
+ return
+ end
+ result = scopeGet(name)
+end
+
+function defs.NamePos(p)
+ namePos = p
+end
+
+function defs.Name(str)
+ checkImplementation(str, namePos)
+ return {str, namePos, type = 'name'}
+end
+
+function defs.DOTSPos(p)
+ namePos = p
+end
+
+function defs.DOTS(str)
+ checkImplementation(str, namePos)
+ return {str, namePos, type = 'name'}
+end
+
+function defs.COLONPos(p)
+ colonPos = p
+end
+
+function defs.ColonName(name)
+ name.colon = colonPos
+ return name
+end
+
+function defs.LocalVar(names)
+ for _, name in ipairs(names) do
+ scopeSet(name)
+ end
+end
+
+function defs.LocalSet(names)
+ for _, name in ipairs(names) do
+ scopeSet(name)
+ end
+end
+
+function defs.Set(simples)
+ for _, simple in ipairs(simples) do
+ if simple.type == 'simple' and #simple == 1 then
+ local obj = simple[1]
+ local name = obj[1]
+ globalSet(obj)
+ end
+ end
+end
+
+function defs.Simple(...)
+ return { type = 'simple', ... }
+end
+
+function defs.ArgList(...)
+ if ... == '' then
+ return DUMMY_TABLE
+ end
+ return { type = 'list', ... }
+end
+
+function defs.FuncName(...)
+ if ... == '' then
+ return DUMMY_TABLE
+ end
+ return { type = 'simple', ... }
+end
+
+function defs.FunctionDef(simple, args)
+ if #simple == 1 then
+ globalSet(simple[1])
+ end
+ scopePush()
+ -- 判断隐藏的局部变量self
+ if #simple > 0 then
+ local name = simple[#simple]
+ if name.colon then
+ scopeSet {'self', name.colon, name.colon, type = 'name'}
+ end
+ end
+ for _, arg in ipairs(args) do
+ if arg.type == 'simple' and #arg == 1 then
+ local name = arg[1]
+ scopeSet(name)
+ end
+ if arg.type == 'name' then
+ scopeSet(arg)
+ end
+ end
+end
+
+function defs.FunctionLoc(simple, args)
+ if #simple == 1 then
+ scopeSet(simple[1])
+ end
+ scopePush()
+ -- 判断隐藏的局部变量self
+ if #simple > 0 then
+ local name = simple[#simple]
+ if name.colon then
+ scopeSet {'self', name.colon, name.colon, type = 'name'}
+ end
+ end
+ for _, arg in ipairs(args) do
+ if arg.type == 'simple' and #arg == 1 then
+ local name = arg[1]
+ scopeSet(name)
+ end
+ if arg.type == 'name' then
+ scopeSet(arg)
+ end
+ end
+end
+
+function defs.Function()
+ scopePop()
+end
+
+function defs.DoDef()
+ scopePush()
+end
+
+function defs.Do()
+ scopePop()
+end
+
+function defs.IfDef()
+ scopePush()
+end
+
+function defs.If()
+ scopePop()
+end
+
+function defs.ElseIfDef()
+ scopePush()
+end
+
+function defs.ElseIf()
+ scopePop()
+end
+
+function defs.ElseDef()
+ scopePush()
+end
+
+function defs.Else()
+ scopePop()
+end
+
+function defs.LoopDef(name)
+ scopePush()
+ scopeSet(name)
+end
+
+function defs.Loop()
+ scopePop()
+end
+
+function defs.LoopStart(name, exp)
+ return name
+end
+
+function defs.NameList(...)
+ return { type = 'list', ... }
+end
+
+function defs.SimpleList(...)
+ return { type = 'list', ... }
+end
+
+function defs.InDef(names)
+ scopePush()
+ for _, name in ipairs(names) do
+ scopeSet(name)
+ end
+end
+
+function defs.In()
+ scopePop()
+end
+
+function defs.WhileDef()
+ scopePush()
+end
+
+function defs.While()
+ scopePop()
+end
+
+function defs.RepeatDef()
+ scopePush()
+end
+
+function defs.Until()
+ scopePop()
+end
+
+return function (buf, pos_)
+ pos = pos_
+ result = nil
+ scopeInit()
+
+ local suc, err = parser.grammar(buf, 'Lua', defs)
+ if not suc then
+ return false, '语法错误', err
+ end
+
+ if not result then
+ return false, 'No word'
+ end
+ local name, start, finish = result[1], result[2], result[3]
+ if not start then
+ return false, 'No match'
+ end
+ if not finish then
+ finish = start + #name - 1
+ end
+ return true, start, finish
+end
diff --git a/src/matcher/init.lua b/src/matcher/init.lua
index 8e2a01e2..c570b342 100644
--- a/src/matcher/init.lua
+++ b/src/matcher/init.lua
@@ -1,5 +1,6 @@
local api = {
- definition = require 'matcher.definition',
+ definition = require 'matcher.definition',
+ implementation = require 'matcher.implementation',
}
return api
diff --git a/src/method/init.lua b/src/method/init.lua
index 214997d7..46b79cc4 100644
--- a/src/method/init.lua
+++ b/src/method/init.lua
@@ -7,6 +7,7 @@ end
init 'initialize'
init 'initialized'
init 'shutdown'
+init 'textDocument/implementation'
init 'textDocument/definition'
init 'textDocument/didOpen'
init 'textDocument/didChange'
diff --git a/src/method/initialize.lua b/src/method/initialize.lua
index ea5bb3c4..866ded66 100644
--- a/src/method/initialize.lua
+++ b/src/method/initialize.lua
@@ -5,7 +5,7 @@ return function (lsp, data)
-- 支持“转到定义”
definitionProvider = true,
-- 支持“转到实现”
- --implementationProvider = true,
+ implementationProvider = true,
-- 文本同步方式
textDocumentSync = {
-- 打开关闭文本时通知
diff --git a/src/method/textDocument/implementation.lua b/src/method/textDocument/implementation.lua
new file mode 100644
index 00000000..95bd0c20
--- /dev/null
+++ b/src/method/textDocument/implementation.lua
@@ -0,0 +1,46 @@
+local parser = require 'parser'
+local matcher = require 'matcher'
+
+return function (lsp, params)
+ local uri = params.textDocument.uri
+ local text = lsp:loadText(uri)
+ if not text then
+ return nil, '找不到文件:' .. uri
+ end
+ local start_clock = os.clock()
+ -- lua是从1开始的,因此都要+1
+ local pos = parser.calcline.position_utf8(text, params.position.line + 1, params.position.character + 1)
+ local suc, start, finish = matcher.implementation(text, pos)
+ if not suc then
+ if finish then
+ log.debug(start, uri)
+ finish.lua = nil
+ log.debug(table.dump(finish))
+ end
+ return {}
+ end
+
+ local start_row, start_col = parser.calcline.rowcol_utf8(text, start)
+ local finish_row, finish_col = parser.calcline.rowcol_utf8(text, finish)
+
+ local response = {
+ uri = uri,
+ range = {
+ start = {
+ line = start_row - 1,
+ character = start_col - 1,
+ },
+ ['end'] = {
+ line = finish_row - 1,
+ -- 这里不用-1,因为前端期待的是匹配完成后的位置
+ character = finish_col,
+ },
+ },
+ }
+ local passed_clock = os.clock() - start_clock
+ if passed_clock >= 0.01 then
+ log.warn(('[转到实现]耗时[%.3f]秒,文件大小[%s]字节'):format(passed_clock, #text))
+ end
+
+ return response
+end
diff --git a/test/definition/set.lua b/test/definition/set.lua
index 4cbb5926..2e48e490 100644
--- a/test/definition/set.lua
+++ b/test/definition/set.lua
@@ -2,3 +2,29 @@ TEST [[
<!x!> = 1
<?x?> = 1
]]
+
+TEST [[
+do
+ <!global!> = 1
+end
+<?global?> = 1
+]]
+
+TEST [[
+<!x!> = 1
+do
+ local x = 1
+end
+<?x?> = 1
+]]
+
+TEST [[
+x = 1
+do
+ local <!x!> = 1
+ do
+ x = 2
+ end
+ <?x?> = 1
+end
+]]
diff --git a/test/implementation/arg.lua b/test/implementation/arg.lua
new file mode 100644
index 00000000..2004d666
--- /dev/null
+++ b/test/implementation/arg.lua
@@ -0,0 +1,23 @@
+TEST [[
+local function xx (<!xx!>)
+ <?xx?> = 1
+end
+]]
+
+TEST [[
+local function x (x, <!...!>)
+ x = <?...?>
+end
+]]
+
+TEST [[
+function mt<!:!>x()
+ <?self?> = 1
+end
+]]
+
+TEST [[
+function mt:x(<!self!>)
+ <?self?> = 1
+end
+]]
diff --git a/test/implementation/bug.lua b/test/implementation/bug.lua
new file mode 100644
index 00000000..b0e890ca
--- /dev/null
+++ b/test/implementation/bug.lua
@@ -0,0 +1,15 @@
+TEST [[
+local <!x!>
+function _(x)
+end
+function _()
+ <?x?>
+end
+]]
+
+TEST [[
+function _(<!x!>)
+ do return end
+ <?x?> = 1
+end
+]]
diff --git a/test/implementation/function.lua b/test/implementation/function.lua
new file mode 100644
index 00000000..90b75da8
--- /dev/null
+++ b/test/implementation/function.lua
@@ -0,0 +1,24 @@
+
+TEST [[
+function <!x!> () end
+<?x?> = 1
+]]
+
+TEST [[
+local function <!x!> () end
+<?x?> = 1
+]]
+
+TEST [[
+local x
+local function <!x!> ()
+ <?x?> = 1
+end
+]]
+
+TEST [[
+local x
+function <!x!>()
+end
+<?x?> = 1
+]]
diff --git a/test/implementation/init.lua b/test/implementation/init.lua
new file mode 100644
index 00000000..3033f340
--- /dev/null
+++ b/test/implementation/init.lua
@@ -0,0 +1,22 @@
+local matcher = require 'matcher'
+
+rawset(_G, 'TEST', true)
+
+function TEST(script)
+ local start = script:find('<!', 1, true) + 2
+ local finish = script:find('!>', 1, true) - 1
+ local pos = script:find('<?', 1, true) + 2
+ local new_script = script:gsub('<[!?]', ' '):gsub('[!?]>', ' ')
+
+ local suc, a, b = matcher.implementation(new_script, pos)
+ assert(suc)
+ assert(a == start)
+ assert(b == finish)
+end
+
+require 'implementation.set'
+require 'implementation.local'
+require 'implementation.arg'
+require 'implementation.function'
+--require 'implementation.table'
+require 'implementation.bug'
diff --git a/test/implementation/local.lua b/test/implementation/local.lua
new file mode 100644
index 00000000..0737443d
--- /dev/null
+++ b/test/implementation/local.lua
@@ -0,0 +1,191 @@
+TEST [[
+local <!x!>
+<?x?> = 1
+]]
+
+TEST [[
+local z, y, <!x!>
+<?x?> = 1
+]]
+
+TEST [[
+local <!x!> = 1
+<?x?> = 1
+]]
+
+TEST [[
+local z, y, <!x!> = 1
+<?x?> = 1
+]]
+
+TEST [[
+local x
+local <!x!>
+<?x?> = 1
+]]
+
+TEST [[
+local <!x!>
+do
+ <?x?> = 1
+end
+]]
+
+TEST [[
+local <!x!>
+do
+ local x
+end
+<?x?> = 1
+]]
+
+TEST [[
+local <!x!>
+if <?x?> then
+ local x
+end
+]]
+
+TEST [[
+local <!x!>
+if x then
+ local x
+elseif <?x?> then
+ local x
+end
+]]
+
+TEST [[
+local <!x!>
+if x then
+ local x
+elseif x then
+ local x
+else
+ local x
+end
+<?x?> = 1
+]]
+
+TEST [[
+local <!x!>
+if x then
+ <?x?> = 1
+elseif x then
+ local x
+else
+ local x
+end
+]]
+
+TEST [[
+local <!x!>
+for x = 1, 10 do
+end
+<?x?> = 1
+]]
+
+TEST [[
+local x
+for <!x!> = 1, 10 do
+ <?x?> = 1
+end
+]]
+
+TEST [[
+local <!x!>
+for x in x do
+end
+<?x?> = 1
+]]
+
+TEST [[
+local <!x!>
+for x in <?x?> do
+end
+]]
+
+TEST [[
+local x
+for <!x!> in x do
+ <?x?> = 1
+end
+]]
+
+TEST [[
+local x
+for z, y, <!x!> in x do
+ <?x?> = 1
+end
+]]
+
+TEST [[
+local <!x!>
+while <?x?> do
+end
+]]
+
+TEST [[
+local <!x!>
+while x do
+ <?x?> = 1
+end
+]]
+
+TEST [[
+local <!x!>
+while x do
+ local x
+end
+<?x?> = 1
+]]
+
+TEST [[
+local <!x!>
+repeat
+ <?x?> = 1
+until true
+]]
+
+TEST [[
+local <!x!>
+repeat
+ local x
+until true
+<?x?> = 1
+]]
+
+TEST [[
+local <!x!>
+repeat
+until <?x?>
+]]
+
+TEST [[
+local x
+repeat
+ local <!x!>
+until <?x?>
+]]
+
+TEST [[
+local <!x!>
+function _()
+ local x
+end
+<?x?> = 1
+]]
+
+TEST [[
+local <!x!>
+return function ()
+ <?x?> = 1
+end
+]]
+
+TEST [[
+local <!x!>
+local x = function ()
+ <?x?> = 1
+end
+]]
diff --git a/test/implementation/set.lua b/test/implementation/set.lua
new file mode 100644
index 00000000..5c4a1a2e
--- /dev/null
+++ b/test/implementation/set.lua
@@ -0,0 +1,31 @@
+TEST [[
+<!x!> = 1
+<?x?> = 1
+]]
+
+TEST [[
+global = 1
+do
+ <!global!> = 2
+end
+<?global?> = 3
+]]
+
+TEST [[
+<!x!> = 1
+do
+ local x = 1
+end
+<?x?> = 1
+]]
+
+TEST [[
+x = 1
+do
+ local x = 1
+ do
+ <!x!> = 2
+ end
+ <?x?> = 1
+end
+]]
diff --git a/test/implementation/table.lua b/test/implementation/table.lua
new file mode 100644
index 00000000..13a3b555
--- /dev/null
+++ b/test/implementation/table.lua
@@ -0,0 +1,6 @@
+TEST [[
+local t = {
+ <!x!> = 1,
+}
+t.<?x?> = 1
+]]
diff --git a/test/main.lua b/test/main.lua
index 972a1068..c33526c6 100644
--- a/test/main.lua
+++ b/test/main.lua
@@ -21,6 +21,7 @@ local function main()
end
test 'definition'
+ test 'implementation'
print('测试完成')
end