diff options
-rw-r--r-- | script-beta/core/diagnostics/redundant-parameter.lua | 2 | ||||
-rw-r--r-- | script-beta/core/hover/init.lua | 2 | ||||
-rw-r--r-- | script-beta/core/hover/label.lua | 18 | ||||
-rw-r--r-- | script-beta/core/hover/return.lua | 2 | ||||
-rw-r--r-- | script-beta/core/hover/table.lua | 3 | ||||
-rw-r--r-- | script-beta/parser/guide.lua | 175 | ||||
-rw-r--r-- | script-beta/vm/getClass.lua | 2 | ||||
-rw-r--r-- | script-beta/vm/getInfer.lua (renamed from script-beta/vm/getValue.lua) | 65 | ||||
-rw-r--r-- | script-beta/vm/getLibrary.lua | 2 | ||||
-rw-r--r-- | script-beta/vm/guideInterface.lua | 2 | ||||
-rw-r--r-- | script-beta/vm/init.lua | 2 | ||||
-rw-r--r-- | test-beta/type_inference/init.lua | 2 |
12 files changed, 135 insertions, 142 deletions
diff --git a/script-beta/core/diagnostics/redundant-parameter.lua b/script-beta/core/diagnostics/redundant-parameter.lua index d619ba10..b424c2bf 100644 --- a/script-beta/core/diagnostics/redundant-parameter.lua +++ b/script-beta/core/diagnostics/redundant-parameter.lua @@ -65,7 +65,7 @@ return function (uri, callback) if not vm.hasType(func, 'function') then return end - local values = vm.getValue(func) + local values = vm.getInfers(func) for _, value in ipairs(values) do if value.type and value.source.type == 'function' then local args = countFuncArgs(value.source) diff --git a/script-beta/core/hover/init.lua b/script-beta/core/hover/init.lua index 007e8aad..e123bb06 100644 --- a/script-beta/core/hover/init.lua +++ b/script-beta/core/hover/init.lua @@ -7,7 +7,7 @@ local util = require 'utility' local findSource = require 'core.find-source' local function getHoverAsFunction(source) - local values = vm.getValue(source) + local values = vm.getInfers(source) local desc = getDesc(source) local labels = {} local defs = 0 diff --git a/script-beta/core/hover/label.lua b/script-beta/core/hover/label.lua index 5e14c68e..a776f0c4 100644 --- a/script-beta/core/hover/label.lua +++ b/script-beta/core/hover/label.lua @@ -18,17 +18,19 @@ end local function asValue(source, title) local name = buildName(source) - local class, type, literal, cont - local values = vm.getValue(source) + local class = 'any' + local type = 'any' + local literal, cont + local values = vm.getInfers(source) if values then for _, value in ipairs(values) do local src = value.source local tp = value.type - class = guide.mergeInfers(class, vm.getClass(src)) - type = guide.mergeInfers(type, tp) + class = guide.mergeTypes {class, vm.getClass(src)} + type = guide.mergeTypes {type, tp} local sl = vm.getLiteral(src) if sl then - literal = guide.mergeInfers(literal, util.viewLiteral(sl)) + literal = guide.mergeTypes {literal, util.viewLiteral(sl)} end if tp == 'table' then cont = buildTable(src) @@ -36,11 +38,11 @@ local function asValue(source, title) end end vm.eachDef(source, function (src) - class = guide.mergeInfers(class, vm.getClass(src)) - type = guide.mergeInfers(type, vm.getType(src)) + class = guide.mergeTypes {class, vm.getClass(src)} + type = guide.mergeTypes {type, vm.getType(src)} local sl = vm.getLiteral(src) if sl then - literal = guide.mergeInfers(literal, util.viewLiteral(sl)) + literal = guide.mergeTypes {literal, util.viewLiteral(sl)} end if type == 'table' then cont = buildTable(src) diff --git a/script-beta/core/hover/return.lua b/script-beta/core/hover/return.lua index 82dc1314..f67a961f 100644 --- a/script-beta/core/hover/return.lua +++ b/script-beta/core/hover/return.lua @@ -36,7 +36,7 @@ local function asFunction(source) local returns = {} for _, rtn in ipairs(source.returns) do for i = 1, #rtn do - local values = vm.getValue(rtn[i]) + local values = vm.getInfers(rtn[i]) returns[#returns+1] = values end break diff --git a/script-beta/core/hover/table.lua b/script-beta/core/hover/table.lua index d00440a7..5b086cd6 100644 --- a/script-beta/core/hover/table.lua +++ b/script-beta/core/hover/table.lua @@ -1,5 +1,6 @@ local vm = require 'vm' local util = require 'utility' +local guide = require 'parser.guide' local function getKey(src) local key = vm.getKeyName(src) @@ -118,7 +119,7 @@ return function (source) local intValue = true vm.eachField(source, function (src) local key, class, literal = getField(src) - classes[key] = guide.mergeInfers(class, classes[key]) + classes[key] = guide.mergeTypes {class, classes[key]} literals[key] = mergeLiteral(literal, literals[key]) if class ~= 'integer' or not literals[key] diff --git a/script-beta/parser/guide.lua b/script-beta/parser/guide.lua index 228b0a61..4f15a675 100644 --- a/script-beta/parser/guide.lua +++ b/script-beta/parser/guide.lua @@ -1423,8 +1423,6 @@ function m.checkSameSimpleAsReturn(status, ref, start, queue) if ref.parent.type ~= 'return' then return end - -- TODO 这里的开销非常大 - --do return end if ref.parent.parent.type ~= 'main' then return end @@ -1786,7 +1784,7 @@ function m.cleanResults(results) end end -function m.getCache(status, obj, mode) +function m.getRefCache(status, obj, mode) if not status.interface.cache then return end @@ -1802,7 +1800,7 @@ end function m.searchRefs(status, obj, mode) status.depth = status.depth + 1 - local cache, makeCache = m.getCache(status, obj, mode) + local cache, makeCache = m.getRefCache(status, obj, mode) if cache then for i = 1, #cache do status.results[#status.results+1] = cache[i] @@ -1844,23 +1842,6 @@ function m.searchRefOfValue(status, obj) end end -function m.mergeInfer(t, b) - if not t then - t = {} - end - if not b then - return t - end - for i = 1, #b do - local o = b[i] - if not t[o] then - t[o] = true - t[#t+1] = o - end - end - return t -end - function m.allocInfer(o) -- TODO assert(o.type) @@ -1873,31 +1854,26 @@ function m.allocInfer(o) source = o.source, } values[i] = sub - values[sub] = true end return values else return { [1] = o, - [o] = true, } end end -function m.insertInfer(t, o) - if not o then - return - end - if not t[o] then - t[o] = true - t[#t+1] = o +function m.mergeTypes(infers) + local types = {} + for i = 1, #infers do + for tp in infers[i]:gmatch '[^|]+' do + if not types[tp] and tp ~= 'any' then + types[#types+1] = tp + end + end end - return t -end - -local function mergeInfers(types) if #types == 0 then - return nil + return 'any' end if #types == 1 then return types[1] @@ -1922,23 +1898,6 @@ local function mergeInfers(types) return tableConcat(types, '|') end -function m.mergeInfers(...) - local max = select('#', ...) - local views = {} - for i = 1, max do - local view = select(i, ...) - if view then - for tp in view:gmatch '[^|]+' do - if not views[tp] and tp ~= 'any' then - views[tp] = true - views[#views+1] = tp - end - end - end - end - return mergeInfers(views) -end - function m.viewInfer(infers) if not infers then return 'any' @@ -1946,102 +1905,142 @@ function m.viewInfer(infers) if type(infers) ~= 'table' then return infers or 'any' end + local mark = {} local types = {} for i = 1, #infers do local tp = infers[i].type - if tp and not types[tp] and tp ~= 'any' then - types[tp] = true + if not mark[tp] and tp ~= 'any' then types[#types+1] = tp end + mark[tp] = true end - return m.mergeInfers(types) or 'any' + return m.mergeTypes(types) end function m.inferCheckLiteral(status, source) if source.type == 'string' then - return m.alloc { + return m.allocInfer { type = 'string', value = source[1], source = source, } elseif source.type == 'nil' then - return m.alloc { + return m.allocInfer { type = 'nil', value = NIL, source = source, } elseif source.type == 'boolean' then - return m.alloc { + return m.allocInfer { type = 'boolean', value = source[1], source = source, } elseif source.type == 'number' then if mathType(source[1]) == 'integer' then - return m.alloc { + return m.allocInfer { type = 'integer', value = source[1], source = source, } else - return m.alloc { + return m.allocInfer { type = 'number', value = source[1], source = source, } end elseif source.type == 'integer' then - return m.alloc { + return m.allocInfer { type = 'integer', source = source, } elseif source.type == 'table' then - return m.alloc { + return m.allocInfer { type = 'table', source = source, } elseif source.type == 'function' then - return m.alloc { + return m.allocInfer { type = 'function', source = source, } elseif source.type == '...' then - return m.alloc { + return m.allocInfer { type = '...', source = source, } end end +function m.inferByDef(status, obj) + local newStatus = m.status(status) + m.searchRefs(newStatus, obj, 'def') + for _, src in ipairs(newStatus.results) do + local inferStatus = m.status(status) + local infers = m.searchInfer(inferStatus, src) + + end +end + +function m.cleanInfers(infers) + local mark = {} + for i = 1, #infers do + local source = infers[i].source + if mark[source] then + infers[i] = infers[#infers] + infers[#infers] = nil + else + mark[source] = true + end + end +end + function m.searchInfer(status, obj) obj = m.getObjectValue(obj) or obj + + local cache, makeCache + if status.interface.cache then + cache, makeCache = status.interface.cache(obj, 'infer') + end + if cache then + for i = 1, #cache do + status.results[#status.results+1] = cache[i] + end + return + end + local results = m.inferCheckLiteral(status, obj) - --or inferCheckUnary(obj) - --or inferCheckBinary(obj) - --or inferCheckLibraryTypes(obj) - --or inferCheckLibrary(obj) - --or inferCheckSpecialReturn(obj) - --or inferCheckLibraryReturn(obj) + --or m.inferCheckUnary(obj) + --or m.inferCheckBinary(obj) + --or m.inferCheckLibraryTypes(obj) + --or m.inferCheckLibrary(obj) + --or m.inferCheckSpecialReturn(obj) + --or m.inferCheckLibraryReturn(obj) if results then - return results + m.cleanInfers(results) + for i = 1, #results do + status.results[#status.results+1] = results[i] + end + if makeCache then + makeCache(status.results) + end + return end - results = {} - --inferByLibraryArg(results, obj) - --inferByDef(results, source) - --inferBySet(results, obj) - --inferByCall(results, obj) - --inferByGetTable(results, obj) - --inferByUnary(results, obj) - --inferByBinary(results, obj) - --inferByCallReturn(results, obj) - --inferByPCallReturn(results, obj) - - if #results == 0 then - return nil + --inferByLibraryArg(status, obj) + m.inferByDef(status, obj) + --m.inferBySet(status, obj) + --m.inferByCall(status, obj) + --m.inferByGetTable(status, obj) + --m.inferByUnary(status, obj) + --m.inferByBinary(status, obj) + --m.inferByCallReturn(status, obj) + --m.inferByPCallReturn(status, obj) + m.cleanInfers(status.results) + if makeCache then + makeCache(status.results) end - - return results end --- 请求对象的引用,包括 `a.b.c` 形式 @@ -2081,7 +2080,9 @@ end --- 请求对象的类型推测 function m.requestInfer(obj, interface) local status = m.status(nil, interface) - return m.searchInfer(status, obj) + m.searchInfer(status, obj) + + return status.results, status.cache.count end return m diff --git a/script-beta/vm/getClass.lua b/script-beta/vm/getClass.lua index 0d5d45b1..a3394e41 100644 --- a/script-beta/vm/getClass.lua +++ b/script-beta/vm/getClass.lua @@ -41,5 +41,5 @@ function vm.getClass(source) if #classes == 0 then return nil end - return guide.mergeInfers(table.unpack(classes)) + return guide.mergeTypes(classes) end diff --git a/script-beta/vm/getValue.lua b/script-beta/vm/getInfer.lua index bc83dc10..e98ac1a0 100644 --- a/script-beta/vm/getValue.lua +++ b/script-beta/vm/getInfer.lua @@ -189,25 +189,25 @@ local function checkBinary(source) if op.type == 'and' then local isTrue = vm.checkTrue(source[1]) if isTrue == true then - return vm.getValue(source[2]) + return vm.getInfers(source[2]) elseif isTrue == false then - return vm.getValue(source[1]) + return vm.getInfers(source[1]) else return merge( - vm.getValue(source[1]), - vm.getValue(source[2]) + vm.getInfers(source[1]), + vm.getInfers(source[2]) ) end elseif op.type == 'or' then local isTrue = vm.checkTrue(source[1]) if isTrue == true then - return vm.getValue(source[1]) + return vm.getInfers(source[1]) elseif isTrue == false then - return vm.getValue(source[2]) + return vm.getInfers(source[2]) else return merge( - vm.getValue(source[1]), - vm.getValue(source[2]) + vm.getInfers(source[1]), + vm.getInfers(source[2]) ) end elseif op.type == '==' then @@ -711,7 +711,7 @@ local function inferBySetOfLocal(results, source) if ref.type == 'setlocal' then break end - merge(results, vm.getValue(ref)) + merge(results, vm.getInfers(ref)) end end end @@ -724,7 +724,7 @@ local function inferBySet(results, source) inferBySetOfLocal(results, source) elseif source.type == 'setlocal' or source.type == 'getlocal' then - merge(results, vm.getValue(source.node)) + merge(results, vm.getInfers(source.node)) end end @@ -736,7 +736,7 @@ local function mergeFunctionReturns(results, source, index) for i = 1, #returns do local rtn = returns[i] if rtn[index] then - merge(results, vm.getValue(rtn[index])) + merge(results, vm.getInfers(rtn[index])) end end end @@ -749,7 +749,7 @@ local function inferByCallReturn(results, source) return end local node = source.vararg.node - local nodeValues = vm.getValue(node) + local nodeValues = vm.getInfers(node) if not nodeValues then return end @@ -786,7 +786,7 @@ local function inferByPCallReturn(results, source) else return end - local funcValues = vm.getValue(func) + local funcValues = vm.getInfers(func) if not funcValues then return end @@ -817,7 +817,7 @@ function vm.inferValue(source, infer) results = {} inferByLibraryArg(results, source) - --inferByDef(results, source) + inferByDef(results, source) inferBySet(results, source) inferByCall(results, source) inferByGetTable(results, source) @@ -834,7 +834,7 @@ function vm.inferValue(source, infer) end function vm.checkTrue(source) - local values = vm.getValue(source) + local values = vm.getInfers(source) if not values then return end @@ -869,7 +869,7 @@ end --- 获取特定类型的字面量值 function vm.getLiteral(source, type) - local values = vm.getValue(source) + local values = vm.getInfers(source) if not values then return nil end @@ -885,8 +885,8 @@ function vm.getLiteral(source, type) end function vm.isSameValue(a, b) - local valuesA = vm.getValue(a) - local valuesB = vm.getValue(b) + local valuesA = vm.getInfers(a) + local valuesB = vm.getInfers(b) if not valuesA or not valuesB then return false end @@ -921,13 +921,13 @@ end --- 是否包含某种类型 function vm.hasType(source, type) - local values = vm.getValue(source) - if not values then + local infers = vm.getInfers(source) + if not infers then return false end - for i = 1, #values do - local value = values[i] - if value.type == type then + for i = 1, #infers do + local infer = infers[i] + if infer.type == type then return true end end @@ -935,26 +935,15 @@ function vm.hasType(source, type) end function vm.getType(source) - local values = vm.getValue(source) - return guide.viewInfer(values) + local infers = vm.getInfers(source) + return guide.viewInfer(infers) end --- 获取对象的值 --- 会尝试穿透函数调用 -function vm.getValue(source) +function vm.getInfers(source) if not source then return end - local cache = vm.getCache('getValue')[source] - if cache ~= nil then - return cache - end - local unlock = vm.lock('getValue', source) - if not unlock then - return - end - cache = guide.requestInfer(source, vm.interface) or false - vm.getCache('getValue')[source] = cache - unlock() - return cache + return guide.requestInfer(source, vm.interface) end diff --git a/script-beta/vm/getLibrary.lua b/script-beta/vm/getLibrary.lua index 34047805..69bc3f2b 100644 --- a/script-beta/vm/getLibrary.lua +++ b/script-beta/vm/getLibrary.lua @@ -49,7 +49,7 @@ end local function getNodeAsObject(source) local node = source.node - local values = vm.getValue(node) + local values = vm.getInfers(node) if not values then return nil end diff --git a/script-beta/vm/guideInterface.lua b/script-beta/vm/guideInterface.lua index bd7a7c2b..e4c27531 100644 --- a/script-beta/vm/guideInterface.lua +++ b/script-beta/vm/guideInterface.lua @@ -95,7 +95,7 @@ function vm.interface.cache(source, mode) cache[mode] = {} end local sourceCache = cache[mode][source] - if cache[mode][source] then + if sourceCache then return sourceCache end sourceCache = {} diff --git a/script-beta/vm/init.lua b/script-beta/vm/init.lua index bc2b41c1..7621859f 100644 --- a/script-beta/vm/init.lua +++ b/script-beta/vm/init.lua @@ -1,7 +1,7 @@ local vm = require 'vm.vm' require 'vm.getGlobals' require 'vm.getLibrary' -require 'vm.getValue' +require 'vm.getInfer' require 'vm.getClass' require 'vm.getMeta' require 'vm.eachField' diff --git a/test-beta/type_inference/init.lua b/test-beta/type_inference/init.lua index 7842316e..17a2ae5b 100644 --- a/test-beta/type_inference/init.lua +++ b/test-beta/type_inference/init.lua @@ -30,7 +30,7 @@ function TEST(wanted) files.setText('', newScript) local source = getSource(pos) assert(source) - local result = vm.getType(source) or 'any' + local result = vm.getType(source) assert(wanted == result) end end |