From 7ade261608ef3649be5c6ee2961e926527e2f03d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=80=E8=90=8C=E5=B0=8F=E6=B1=90?= Date: Mon, 18 Nov 2019 01:15:20 +0800 Subject: =?UTF-8?q?=E6=9A=82=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server-beta/src/parser/guide.lua | 19 ++- server-beta/src/searcher/init.lua | 2 + server-beta/src/searcher/isTrue.lua | 52 ++++++ server-beta/src/searcher/searcher.lua | 4 + server-beta/src/searcher/typeInference.lua | 253 +++++++++++++++++++++++++++++ server-beta/test.lua | 3 +- server-beta/test/find_lib/init.lua | 106 ------------ server-beta/test/rename/init.lua | 35 +--- server-beta/test/type_inference/init.lua | 62 ++++--- 9 files changed, 367 insertions(+), 169 deletions(-) create mode 100644 server-beta/src/searcher/isTrue.lua create mode 100644 server-beta/src/searcher/typeInference.lua delete mode 100644 server-beta/test/find_lib/init.lua diff --git a/server-beta/src/parser/guide.lua b/server-beta/src/parser/guide.lua index 80f9d160..f91c677e 100644 --- a/server-beta/src/parser/guide.lua +++ b/server-beta/src/parser/guide.lua @@ -299,7 +299,10 @@ function m.eachSourceContain(ast, offset, callback) list[len] = nil if m.isInRange(obj, offset) then if m.isContain(obj, offset) then - callback(obj) + local res = callback(obj) + if res ~= nil then + return res + end end m.addChilds(list, obj, m.childMap) end @@ -476,19 +479,23 @@ function m.getKeyName(obj) elseif tp == 'getfield' or tp == 'setfield' or tp == 'tablefield' then - return 's|' .. obj.field[1] + if obj.field then + return 's|' .. obj.field[1] + end elseif tp == 'getmethod' or tp == 'setmethod' then - return 's|' .. obj.method[1] + if obj.method then + return 's|' .. obj.method[1] + end elseif tp == 'getindex' or tp == 'setindex' or tp == 'tableindex' then - return m.getKeyName(obj.index) + if obj.index then + return m.getKeyName(obj.index) + end elseif tp == 'field' or tp == 'method' then return 's|' .. obj[1] - elseif tp == 'index' then - return m.getKeyName(obj.index) elseif tp == 'string' then local s = obj[1] if s then diff --git a/server-beta/src/searcher/init.lua b/server-beta/src/searcher/init.lua index 6ec7bc99..1734c2ea 100644 --- a/server-beta/src/searcher/init.lua +++ b/server-beta/src/searcher/init.lua @@ -6,4 +6,6 @@ require 'searcher.getGlobals' require 'searcher.getLinks' require 'searcher.getGlobal' require 'searcher.getLibrary' +require 'searcher.typeInference' +require 'searcher.isTrue' return searcher diff --git a/server-beta/src/searcher/isTrue.lua b/server-beta/src/searcher/isTrue.lua new file mode 100644 index 00000000..ba825d43 --- /dev/null +++ b/server-beta/src/searcher/isTrue.lua @@ -0,0 +1,52 @@ +local searcher = require 'searcher.searcher' + +local function checkLiteral(source) + if source.type == 'boolean' then + if source[1] == true then + return 'true' + else + return 'false' + end + end + if source.type == 'string' then + return 'true' + end + if source.type == 'number' then + return 'true' + end + if source.type == 'table' then + return 'true' + end + if source.type == 'function' then + return 'true' + end + if source.type == 'nil' then + return 'false' + end +end + +local function isTrue(source) + local res = checkLiteral(source) + if res then + return res + end + return 'unknown' +end + +function searcher.isTrue(source) + if not source then + return + end + local cache = searcher.cache.isTrue[source] + if cache ~= nil then + return cache + end + local unlock = searcher.lock('isTrue', source) + if not unlock then + return + end + cache = isTrue(source) or false + searcher.cache.isTrue[source] = cache + unlock() + return cache +end diff --git a/server-beta/src/searcher/searcher.lua b/server-beta/src/searcher/searcher.lua index 19e114d6..3e12430b 100644 --- a/server-beta/src/searcher/searcher.lua +++ b/server-beta/src/searcher/searcher.lua @@ -64,6 +64,8 @@ function m.refreshCache() getGlobal = {}, specialName = {}, getLibrary = {}, + typeInfer = {}, + isTrue = {}, specials = nil, } m.locked = { @@ -72,6 +74,8 @@ function m.refreshCache() getGlobals = {}, getLinks = {}, getLibrary = {}, + typeInfer = {}, + isTrue = {}, } m.cacheTracker[m.cache] = true end diff --git a/server-beta/src/searcher/typeInference.lua b/server-beta/src/searcher/typeInference.lua new file mode 100644 index 00000000..56f4cc98 --- /dev/null +++ b/server-beta/src/searcher/typeInference.lua @@ -0,0 +1,253 @@ +local searcher = require 'searcher.searcher' +local guide = require 'parser.guide' +local config = require 'config' + +local typeSort = { + ['boolean'] = 1, + ['string'] = 2, + ['integer'] = 3, + ['number'] = 4, + ['table'] = 5, + ['function'] = 6, + ['nil'] = math.maxinteger, +} + +local function merge(result, tp) + if result[tp] then + return + end + if tp:find('|', 1, true) then + for sub in tp:gmatch '[^|]+' do + if not result[sub] then + result[sub] = true + result[#result+1] = sub + end + end + else + result[tp] = true + result[#result+1] = tp + end +end + +local function hasType(tp, target) + if not target then + return false + end + for sub in target:gmatch '[^|]+' do + if sub == tp then + return true + end + end + return false +end + +local function dump(result) + if #result <= 1 then + return result[1] + end + table.sort(result, function (a, b) + local sa = typeSort[a] + local sb = typeSort[b] + if sa and sb then + return sa < sb + end + if not sa and not sb then + return a < b + end + if sa and not sb then + return true + end + if not sa and sb then + return false + end + return false + end) + return table.concat(result, '|') +end + +local function checkLiteral(source) + if source.type == 'string' then + return 'string' + elseif source.type == 'nil' then + return 'nil' + elseif source.type == 'boolean' then + return 'boolean' + elseif source.type == 'number' then + if math.type(source[1]) == 'integer' then + return 'integer' + else + return 'number' + end + elseif source.type == 'table' then + return 'table' + elseif source.type == 'function' then + return 'function' + end +end + +local function checkUnary(source) + if source.type ~= 'unary' then + return + end + local op = source.op + if op.type == 'not' then + return 'boolean' + elseif op.type == '#' + or op.type == '~' then + return 'integer' + elseif op.type == '-' then + if hasType('integer', searcher.typeInference(source[1])) then + return 'integer' + else + return 'number' + end + end +end + +local function checkBinary(source) + if source.type ~= 'binary' then + return + end + local op = source.op + if op.type == 'and' then + local isTrue = searcher.isTrue(source[1]) + if isTrue == 'true' then + return searcher.typeInference(source[2]) + elseif isTrue == 'false' then + return searcher.typeInference(source[1]) + else + local result = {} + merge(result, searcher.typeInference(source[1])) + merge(result, searcher.typeInference(source[2])) + return dump(result) + end + end + if op.type == 'or' then + local isTrue = searcher.isTrue(source[1]) + if isTrue == 'true' then + return searcher.typeInference(source[1]) + elseif isTrue == 'false' then + return searcher.typeInference(source[2]) + else + local result = {} + merge(result, searcher.typeInference(source[1])) + merge(result, searcher.typeInference(source[2])) + return dump(result) + end + end + if op.type == '==' + or op.type == '~=' + or op.type == '<=' + or op.type == '>=' + or op.type == '<' + or op.type == '>' then + return 'boolean' + end + if op.type == '|' + or op.type == '~' + or op.type == '&' + or op.type == '<<' + or op.type == '>>' then + return 'integer' + end + if op.type == '..' then + return 'string' + end + if op.type == '^' + or op.type == '/' then + return 'number' + end + -- 其他数学运算根据2侧的值决定,当2侧的值均为整数时返回整数 + if op.type == '+' + or op.type == '-' + or op.type == '*' + or op.type == '%' + or op.type == '//' then + if hasType('integer', searcher.typeInference(source[1])) + and hasType('integer', searcher.typeInference(source[2])) then + return 'integer' + else + return 'number' + end + end +end + +local function checkValue(source) + if source.value then + return searcher.typeInference(source.value) + end +end + +local function checkCall(result, source) + if not source.parent then + return + end + if source.parent.type ~= 'call' then + return + end + if source.parent.node == source then + merge(result, 'function') + return + end +end + +local function checkNext(result, source) + local next = source.next + if not next then + return + end + if next.type == 'getfield' + or next.type == 'getindex' + or next.type == 'getmethod' + or next.type == 'setfield' + or next.type == 'setindex' + or next.type == 'setmethod' then + merge(result, 'table') + end +end + +local function checkDef(result, source) + searcher.eachDef(source, function (info) + local src = info.source + local tp = searcher.typeInference(src) + if tp then + merge(result, tp) + end + end) +end + +local function typeInference(source) + local tp = checkLiteral(source) + or checkValue(source) + or checkUnary(source) + or checkBinary(source) + if tp then + return tp + end + + local result = {} + + checkCall(result, source) + checkNext(result, source) + checkDef(result, source) + + return dump(result) +end + +function searcher.typeInference(source) + if not source then + return + end + local cache = searcher.cache.typeInfer[source] + if cache ~= nil then + return cache + end + local unlock = searcher.lock('typeInfer', source) + if not unlock then + return + end + cache = typeInference(source) or false + searcher.cache.typeInfer[source] = cache + unlock() + return cache +end diff --git a/server-beta/test.lua b/server-beta/test.lua index 7213d596..dffae940 100644 --- a/server-beta/test.lua +++ b/server-beta/test.lua @@ -41,8 +41,7 @@ local function main() test 'diagnostics' test 'highlight' test 'rename' - --test 'type_inference' - --test 'find_lib' + test 'type_inference' --test 'hover' --test 'completion' --test 'signature' diff --git a/server-beta/test/find_lib/init.lua b/server-beta/test/find_lib/init.lua deleted file mode 100644 index 51f53b75..00000000 --- a/server-beta/test/find_lib/init.lua +++ /dev/null @@ -1,106 +0,0 @@ -local core = require 'core' -local parser = require 'parser' -local buildVM = require 'vm' - -rawset(_G, 'TEST', true) - -function TEST(fullkey) - return function (script) - local start = script:find('', 1, true) - local pos = (start + finish) // 2 + 1 - local new_script = script:gsub('<[!?]', ' '):gsub('[!?]>', ' ') - local ast = parser:parse(new_script, 'lua', 'Lua 5.3') - assert(ast) - local vm = buildVM(ast) - assert(vm) - local source = core.findSource(vm, pos) - local _, name = core.findLib(source) - assert(name == fullkey) - end -end - -TEST 'require' [[ - 'xxx' -]] - -TEST 'req' [[ -local = require -]] - -TEST 'req' [[ -local req = require -local t = { - xx = req, -} -t[]() -]] - -TEST 'table' [[ -.unpack() -]] - -TEST 'xx' [[ -local = require 'table' -]] - -TEST 'xx
' [[ -local rq = require -local lib = 'table' -local = rq(lib) -]] - -TEST 'table.insert' [[ -table.() -]] - -TEST 'table.insert' [[ -local t = table -t.() -]] - -TEST 'table.insert' [[ -local insert = table.insert -() -]] - -TEST 'table.insert' [[ -local t = require 'table' -t.() -]] - -TEST 'table.insert' [[ -require 'table'.() -]] - -TEST '*string.sub' [[ -local str = 'xxx' -str.() -]] - -TEST '*string:sub' [[ -local str = 'xxx' -str:(1, 1) -]] - -TEST '*string.sub' [[ -('xxx').() -]] - -TEST 'fs' [[ -local = require 'bee.filesystem' -]] - -TEST 'fs.current_path' [[ -local filesystem = require 'bee.filesystem' - -ROOT = filesystem.() -]] - -TEST(nil)[[ -print() -]] - -TEST '_G' [[ -local x = -]] diff --git a/server-beta/test/rename/init.lua b/server-beta/test/rename/init.lua index e9fc5aee..8e672e69 100644 --- a/server-beta/test/rename/init.lua +++ b/server-beta/test/rename/init.lua @@ -1,39 +1,6 @@ -local core = require 'core.rename' +local core = require 'core.rename' local files = require 'files' -local function catch_target(script) - local list = {} - local cur = 1 - while true do - local start, finish = script:find('<[!?].-[!?]>', cur) - if not start then - break - end - list[#list+1] = { - start = start + 2, - finish = finish - 2, - } - cur = finish + 1 - end - return list -end - -local function founded(targets, results) - if #targets ~= #results then - return false - end - for _, target in ipairs(targets) do - for _, result in ipairs(results) do - if target.start == result.start and target.finish == result.finish then - goto NEXT - end - end - do return false end - ::NEXT:: - end - return true -end - local function replace(text, positions) local buf = {} table.sort(positions, function (a, b) diff --git a/server-beta/test/type_inference/init.lua b/server-beta/test/type_inference/init.lua index 59d853eb..4a9f30aa 100644 --- a/server-beta/test/type_inference/init.lua +++ b/server-beta/test/type_inference/init.lua @@ -1,22 +1,37 @@ -local parser = require 'parser' -local core = require 'core' -local buildVM = require 'vm' -local config = require 'config' +local files = require 'files' +local config = require 'config' +local searcher = require 'searcher' +local guide = require 'parser.guide' rawset(_G, 'TEST', true) -function TEST(res) +local function getSource(pos) + local ast = files.getAst('') + return guide.eachSourceContain(ast.ast, pos, function (source) + if source.type == 'local' + or source.type == 'getlocal' + or source.type == 'setlocal' + or source.type == 'setglobal' + or source.type == 'getglobal' + or source.type == 'field' + or source.type == 'method' then + return source + end + end) +end + +function TEST(wanted) return function (script) + files.removeAll() local start = script:find('', 1, true) local pos = (start + finish) // 2 + 1 - local new_script = script:gsub('<[!?]', ' '):gsub('[!?]>', ' ') - local ast = parser:parse(new_script, 'lua', 'Lua 5.3') - local vm = buildVM(ast) - assert(vm) - local result = core.findSource(vm, pos) - assert(result) - assert(res == result:bindValue():getType()) + local newScript = script:gsub('<[!?]', ' '):gsub('[!?]>', ' ') + files.setText('', newScript) + local source = getSource(pos) + assert(source) + local result = searcher.typeInference(source) or 'any' + assert(wanted == result) end end @@ -30,20 +45,30 @@ TEST 'boolean' [[ local = true ]] -TEST 'number' [[ +TEST 'integer' [[ local = 1 ]] +TEST 'number' [[ +local = 1.0 +]] + TEST 'string' [[ local var = '111' t. = var ]] -TEST 'string' [[ +TEST 'any' [[ local var = '111' ]] +TEST 'string' [[ +local var +var = '111' +print() +]] + TEST 'function' [[ function () end @@ -55,8 +80,8 @@ end ]] TEST 'function' [[ -local -xx = function () +local xx + = function () end ]] @@ -64,11 +89,6 @@ TEST 'table' [[ local = {} ]] -TEST 'table' [[ -local -t = {} -]] - TEST 'function' [[ () ]] -- cgit v1.2.3