summaryrefslogtreecommitdiff
path: root/script-beta/src/core/highlight.lua
diff options
context:
space:
mode:
Diffstat (limited to 'script-beta/src/core/highlight.lua')
-rw-r--r--script-beta/src/core/highlight.lua230
1 files changed, 230 insertions, 0 deletions
diff --git a/script-beta/src/core/highlight.lua b/script-beta/src/core/highlight.lua
new file mode 100644
index 00000000..61e3f91a
--- /dev/null
+++ b/script-beta/src/core/highlight.lua
@@ -0,0 +1,230 @@
+local guide = require 'parser.guide'
+local files = require 'files'
+local vm = require 'vm'
+local define = require 'proto.define'
+
+local function ofLocal(source, callback)
+ callback(source)
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ callback(ref)
+ end
+ end
+end
+
+local function ofField(source, uri, callback)
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ local myKey = guide.getKeyName(source)
+ if parent.type == 'tableindex'
+ or parent.type == 'tablefield' then
+ local tbl = parent.parent
+ vm.eachField(tbl, function (info)
+ if info.key ~= myKey then
+ return
+ end
+ local destUri = guide.getRoot(info.source).uri
+ if destUri ~= uri then
+ return
+ end
+ callback(info.source)
+ end)
+ else
+ vm.eachField(parent.node, function (info)
+ if info.key ~= myKey then
+ return
+ end
+ local destUri = guide.getRoot(info.source).uri
+ if destUri ~= uri then
+ return
+ end
+ callback(info.source)
+ end)
+ end
+end
+
+local function ofIndex(source, uri, callback)
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ if parent.type == 'setindex'
+ or parent.type == 'getindex'
+ or parent.type == 'tableindex' then
+ ofField(source, uri, callback)
+ end
+end
+
+local function ofLabel(source, callback)
+ vm.eachRef(source, function (info)
+ callback(info.source)
+ end)
+end
+
+local function find(source, uri, callback)
+ if source.type == 'local' then
+ ofLocal(source, callback)
+ elseif source.type == 'getlocal'
+ or source.type == 'setlocal' then
+ ofLocal(source.node, callback)
+ elseif source.type == 'field'
+ or source.type == 'method' then
+ ofField(source, uri, callback)
+ elseif source.type == 'string'
+ or source.type == 'boolean'
+ or source.type == 'number' then
+ ofIndex(source, uri, callback)
+ callback(source)
+ elseif source.type == 'nil' then
+ callback(source)
+ elseif source.type == 'goto'
+ or source.type == 'label' then
+ ofLabel(source, callback)
+ end
+end
+
+local function checkInIf(source, text, offset)
+ -- 检查 end
+ local endA = source.finish - #'end' + 1
+ local endB = source.finish
+ if offset >= endA
+ and offset <= endB
+ and text:sub(endA, endB) == 'end' then
+ return true
+ end
+ -- 检查每个子模块
+ for _, block in ipairs(source) do
+ for i = 1, #block.keyword, 2 do
+ local start = block.keyword[i]
+ local finish = block.keyword[i+1]
+ if offset >= start and offset <= finish then
+ return true
+ end
+ end
+ end
+ return false
+end
+
+local function makeIf(source, text, callback)
+ -- end
+ local endA = source.finish - #'end' + 1
+ local endB = source.finish
+ if text:sub(endA, endB) == 'end' then
+ callback(endA, endB)
+ end
+ -- 每个子模块
+ for _, block in ipairs(source) do
+ for i = 1, #block.keyword, 2 do
+ local start = block.keyword[i]
+ local finish = block.keyword[i+1]
+ callback(start, finish)
+ end
+ end
+ return false
+end
+
+local function findKeyword(source, text, offset, callback)
+ if source.type == 'do'
+ or source.type == 'function'
+ or source.type == 'loop'
+ or source.type == 'in'
+ or source.type == 'while'
+ or source.type == 'repeat' then
+ local ok
+ for i = 1, #source.keyword, 2 do
+ local start = source.keyword[i]
+ local finish = source.keyword[i+1]
+ if offset >= start and offset <= finish then
+ ok = true
+ break
+ end
+ end
+ if ok then
+ for i = 1, #source.keyword, 2 do
+ local start = source.keyword[i]
+ local finish = source.keyword[i+1]
+ callback(start, finish)
+ end
+ end
+ elseif source.type == 'if' then
+ local ok = checkInIf(source, text, offset)
+ if ok then
+ makeIf(source, text, callback)
+ end
+ end
+end
+
+return function (uri, offset)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+ local text = files.getText(uri)
+ local results = {}
+ local mark = {}
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ find(source, uri, function (target)
+ local kind
+ if target.type == 'getfield' then
+ target = target.field
+ kind = define.DocumentHighlightKind.Read
+ elseif target.type == 'setfield'
+ or target.type == 'tablefield' then
+ target = target.field
+ kind = define.DocumentHighlightKind.Write
+ elseif target.type == 'getmethod' then
+ target = target.method
+ kind = define.DocumentHighlightKind.Read
+ elseif target.type == 'setmethod' then
+ target = target.method
+ kind = define.DocumentHighlightKind.Write
+ elseif target.type == 'getindex' then
+ target = target.index
+ kind = define.DocumentHighlightKind.Read
+ elseif target.type == 'setindex'
+ or target.type == 'tableindex' then
+ target = target.index
+ kind = define.DocumentHighlightKind.Write
+ elseif target.type == 'getlocal'
+ or target.type == 'getglobal'
+ or target.type == 'goto' then
+ kind = define.DocumentHighlightKind.Read
+ elseif target.type == 'setlocal'
+ or target.type == 'local'
+ or target.type == 'setglobal'
+ or target.type == 'label' then
+ kind = define.DocumentHighlightKind.Write
+ elseif target.type == 'string'
+ or target.type == 'boolean'
+ or target.type == 'number'
+ or target.type == 'nil' then
+ kind = define.DocumentHighlightKind.Text
+ else
+ log.warn('Unknow target.type:', target.type)
+ return
+ end
+ if mark[target] then
+ return
+ end
+ mark[target] = true
+ results[#results+1] = {
+ start = target.start,
+ finish = target.finish,
+ kind = kind,
+ }
+ end)
+ findKeyword(source, text, offset, function (start, finish)
+ results[#results+1] = {
+ start = start,
+ finish = finish,
+ kind = define.DocumentHighlightKind.Write
+ }
+ end)
+ end)
+ if #results == 0 then
+ return nil
+ end
+ return results
+end