summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script-beta/core/diagnostics/redundant-parameter.lua2
-rw-r--r--script-beta/core/hover/init.lua2
-rw-r--r--script-beta/core/hover/label.lua18
-rw-r--r--script-beta/core/hover/return.lua2
-rw-r--r--script-beta/core/hover/table.lua3
-rw-r--r--script-beta/parser/guide.lua175
-rw-r--r--script-beta/vm/getClass.lua2
-rw-r--r--script-beta/vm/getInfer.lua (renamed from script-beta/vm/getValue.lua)65
-rw-r--r--script-beta/vm/getLibrary.lua2
-rw-r--r--script-beta/vm/guideInterface.lua2
-rw-r--r--script-beta/vm/init.lua2
-rw-r--r--test-beta/type_inference/init.lua2
12 files changed, 135 insertions, 142 deletions
diff --git a/script-beta/core/diagnostics/redundant-parameter.lua b/script-beta/core/diagnostics/redundant-parameter.lua
index d619ba10..b424c2bf 100644
--- a/script-beta/core/diagnostics/redundant-parameter.lua
+++ b/script-beta/core/diagnostics/redundant-parameter.lua
@@ -65,7 +65,7 @@ return function (uri, callback)
if not vm.hasType(func, 'function') then
return
end
- local values = vm.getValue(func)
+ local values = vm.getInfers(func)
for _, value in ipairs(values) do
if value.type and value.source.type == 'function' then
local args = countFuncArgs(value.source)
diff --git a/script-beta/core/hover/init.lua b/script-beta/core/hover/init.lua
index 007e8aad..e123bb06 100644
--- a/script-beta/core/hover/init.lua
+++ b/script-beta/core/hover/init.lua
@@ -7,7 +7,7 @@ local util = require 'utility'
local findSource = require 'core.find-source'
local function getHoverAsFunction(source)
- local values = vm.getValue(source)
+ local values = vm.getInfers(source)
local desc = getDesc(source)
local labels = {}
local defs = 0
diff --git a/script-beta/core/hover/label.lua b/script-beta/core/hover/label.lua
index 5e14c68e..a776f0c4 100644
--- a/script-beta/core/hover/label.lua
+++ b/script-beta/core/hover/label.lua
@@ -18,17 +18,19 @@ end
local function asValue(source, title)
local name = buildName(source)
- local class, type, literal, cont
- local values = vm.getValue(source)
+ local class = 'any'
+ local type = 'any'
+ local literal, cont
+ local values = vm.getInfers(source)
if values then
for _, value in ipairs(values) do
local src = value.source
local tp = value.type
- class = guide.mergeInfers(class, vm.getClass(src))
- type = guide.mergeInfers(type, tp)
+ class = guide.mergeTypes {class, vm.getClass(src)}
+ type = guide.mergeTypes {type, tp}
local sl = vm.getLiteral(src)
if sl then
- literal = guide.mergeInfers(literal, util.viewLiteral(sl))
+ literal = guide.mergeTypes {literal, util.viewLiteral(sl)}
end
if tp == 'table' then
cont = buildTable(src)
@@ -36,11 +38,11 @@ local function asValue(source, title)
end
end
vm.eachDef(source, function (src)
- class = guide.mergeInfers(class, vm.getClass(src))
- type = guide.mergeInfers(type, vm.getType(src))
+ class = guide.mergeTypes {class, vm.getClass(src)}
+ type = guide.mergeTypes {type, vm.getType(src)}
local sl = vm.getLiteral(src)
if sl then
- literal = guide.mergeInfers(literal, util.viewLiteral(sl))
+ literal = guide.mergeTypes {literal, util.viewLiteral(sl)}
end
if type == 'table' then
cont = buildTable(src)
diff --git a/script-beta/core/hover/return.lua b/script-beta/core/hover/return.lua
index 82dc1314..f67a961f 100644
--- a/script-beta/core/hover/return.lua
+++ b/script-beta/core/hover/return.lua
@@ -36,7 +36,7 @@ local function asFunction(source)
local returns = {}
for _, rtn in ipairs(source.returns) do
for i = 1, #rtn do
- local values = vm.getValue(rtn[i])
+ local values = vm.getInfers(rtn[i])
returns[#returns+1] = values
end
break
diff --git a/script-beta/core/hover/table.lua b/script-beta/core/hover/table.lua
index d00440a7..5b086cd6 100644
--- a/script-beta/core/hover/table.lua
+++ b/script-beta/core/hover/table.lua
@@ -1,5 +1,6 @@
local vm = require 'vm'
local util = require 'utility'
+local guide = require 'parser.guide'
local function getKey(src)
local key = vm.getKeyName(src)
@@ -118,7 +119,7 @@ return function (source)
local intValue = true
vm.eachField(source, function (src)
local key, class, literal = getField(src)
- classes[key] = guide.mergeInfers(class, classes[key])
+ classes[key] = guide.mergeTypes {class, classes[key]}
literals[key] = mergeLiteral(literal, literals[key])
if class ~= 'integer'
or not literals[key]
diff --git a/script-beta/parser/guide.lua b/script-beta/parser/guide.lua
index 228b0a61..4f15a675 100644
--- a/script-beta/parser/guide.lua
+++ b/script-beta/parser/guide.lua
@@ -1423,8 +1423,6 @@ function m.checkSameSimpleAsReturn(status, ref, start, queue)
if ref.parent.type ~= 'return' then
return
end
- -- TODO 这里的开销非常大
- --do return end
if ref.parent.parent.type ~= 'main' then
return
end
@@ -1786,7 +1784,7 @@ function m.cleanResults(results)
end
end
-function m.getCache(status, obj, mode)
+function m.getRefCache(status, obj, mode)
if not status.interface.cache then
return
end
@@ -1802,7 +1800,7 @@ end
function m.searchRefs(status, obj, mode)
status.depth = status.depth + 1
- local cache, makeCache = m.getCache(status, obj, mode)
+ local cache, makeCache = m.getRefCache(status, obj, mode)
if cache then
for i = 1, #cache do
status.results[#status.results+1] = cache[i]
@@ -1844,23 +1842,6 @@ function m.searchRefOfValue(status, obj)
end
end
-function m.mergeInfer(t, b)
- if not t then
- t = {}
- end
- if not b then
- return t
- end
- for i = 1, #b do
- local o = b[i]
- if not t[o] then
- t[o] = true
- t[#t+1] = o
- end
- end
- return t
-end
-
function m.allocInfer(o)
-- TODO
assert(o.type)
@@ -1873,31 +1854,26 @@ function m.allocInfer(o)
source = o.source,
}
values[i] = sub
- values[sub] = true
end
return values
else
return {
[1] = o,
- [o] = true,
}
end
end
-function m.insertInfer(t, o)
- if not o then
- return
- end
- if not t[o] then
- t[o] = true
- t[#t+1] = o
+function m.mergeTypes(infers)
+ local types = {}
+ for i = 1, #infers do
+ for tp in infers[i]:gmatch '[^|]+' do
+ if not types[tp] and tp ~= 'any' then
+ types[#types+1] = tp
+ end
+ end
end
- return t
-end
-
-local function mergeInfers(types)
if #types == 0 then
- return nil
+ return 'any'
end
if #types == 1 then
return types[1]
@@ -1922,23 +1898,6 @@ local function mergeInfers(types)
return tableConcat(types, '|')
end
-function m.mergeInfers(...)
- local max = select('#', ...)
- local views = {}
- for i = 1, max do
- local view = select(i, ...)
- if view then
- for tp in view:gmatch '[^|]+' do
- if not views[tp] and tp ~= 'any' then
- views[tp] = true
- views[#views+1] = tp
- end
- end
- end
- end
- return mergeInfers(views)
-end
-
function m.viewInfer(infers)
if not infers then
return 'any'
@@ -1946,102 +1905,142 @@ function m.viewInfer(infers)
if type(infers) ~= 'table' then
return infers or 'any'
end
+ local mark = {}
local types = {}
for i = 1, #infers do
local tp = infers[i].type
- if tp and not types[tp] and tp ~= 'any' then
- types[tp] = true
+ if not mark[tp] and tp ~= 'any' then
types[#types+1] = tp
end
+ mark[tp] = true
end
- return m.mergeInfers(types) or 'any'
+ return m.mergeTypes(types)
end
function m.inferCheckLiteral(status, source)
if source.type == 'string' then
- return m.alloc {
+ return m.allocInfer {
type = 'string',
value = source[1],
source = source,
}
elseif source.type == 'nil' then
- return m.alloc {
+ return m.allocInfer {
type = 'nil',
value = NIL,
source = source,
}
elseif source.type == 'boolean' then
- return m.alloc {
+ return m.allocInfer {
type = 'boolean',
value = source[1],
source = source,
}
elseif source.type == 'number' then
if mathType(source[1]) == 'integer' then
- return m.alloc {
+ return m.allocInfer {
type = 'integer',
value = source[1],
source = source,
}
else
- return m.alloc {
+ return m.allocInfer {
type = 'number',
value = source[1],
source = source,
}
end
elseif source.type == 'integer' then
- return m.alloc {
+ return m.allocInfer {
type = 'integer',
source = source,
}
elseif source.type == 'table' then
- return m.alloc {
+ return m.allocInfer {
type = 'table',
source = source,
}
elseif source.type == 'function' then
- return m.alloc {
+ return m.allocInfer {
type = 'function',
source = source,
}
elseif source.type == '...' then
- return m.alloc {
+ return m.allocInfer {
type = '...',
source = source,
}
end
end
+function m.inferByDef(status, obj)
+ local newStatus = m.status(status)
+ m.searchRefs(newStatus, obj, 'def')
+ for _, src in ipairs(newStatus.results) do
+ local inferStatus = m.status(status)
+ local infers = m.searchInfer(inferStatus, src)
+
+ end
+end
+
+function m.cleanInfers(infers)
+ local mark = {}
+ for i = 1, #infers do
+ local source = infers[i].source
+ if mark[source] then
+ infers[i] = infers[#infers]
+ infers[#infers] = nil
+ else
+ mark[source] = true
+ end
+ end
+end
+
function m.searchInfer(status, obj)
obj = m.getObjectValue(obj) or obj
+
+ local cache, makeCache
+ if status.interface.cache then
+ cache, makeCache = status.interface.cache(obj, 'infer')
+ end
+ if cache then
+ for i = 1, #cache do
+ status.results[#status.results+1] = cache[i]
+ end
+ return
+ end
+
local results = m.inferCheckLiteral(status, obj)
- --or inferCheckUnary(obj)
- --or inferCheckBinary(obj)
- --or inferCheckLibraryTypes(obj)
- --or inferCheckLibrary(obj)
- --or inferCheckSpecialReturn(obj)
- --or inferCheckLibraryReturn(obj)
+ --or m.inferCheckUnary(obj)
+ --or m.inferCheckBinary(obj)
+ --or m.inferCheckLibraryTypes(obj)
+ --or m.inferCheckLibrary(obj)
+ --or m.inferCheckSpecialReturn(obj)
+ --or m.inferCheckLibraryReturn(obj)
if results then
- return results
+ m.cleanInfers(results)
+ for i = 1, #results do
+ status.results[#status.results+1] = results[i]
+ end
+ if makeCache then
+ makeCache(status.results)
+ end
+ return
end
- results = {}
- --inferByLibraryArg(results, obj)
- --inferByDef(results, source)
- --inferBySet(results, obj)
- --inferByCall(results, obj)
- --inferByGetTable(results, obj)
- --inferByUnary(results, obj)
- --inferByBinary(results, obj)
- --inferByCallReturn(results, obj)
- --inferByPCallReturn(results, obj)
-
- if #results == 0 then
- return nil
+ --inferByLibraryArg(status, obj)
+ m.inferByDef(status, obj)
+ --m.inferBySet(status, obj)
+ --m.inferByCall(status, obj)
+ --m.inferByGetTable(status, obj)
+ --m.inferByUnary(status, obj)
+ --m.inferByBinary(status, obj)
+ --m.inferByCallReturn(status, obj)
+ --m.inferByPCallReturn(status, obj)
+ m.cleanInfers(status.results)
+ if makeCache then
+ makeCache(status.results)
end
-
- return results
end
--- 请求对象的引用,包括 `a.b.c` 形式
@@ -2081,7 +2080,9 @@ end
--- 请求对象的类型推测
function m.requestInfer(obj, interface)
local status = m.status(nil, interface)
- return m.searchInfer(status, obj)
+ m.searchInfer(status, obj)
+
+ return status.results, status.cache.count
end
return m
diff --git a/script-beta/vm/getClass.lua b/script-beta/vm/getClass.lua
index 0d5d45b1..a3394e41 100644
--- a/script-beta/vm/getClass.lua
+++ b/script-beta/vm/getClass.lua
@@ -41,5 +41,5 @@ function vm.getClass(source)
if #classes == 0 then
return nil
end
- return guide.mergeInfers(table.unpack(classes))
+ return guide.mergeTypes(classes)
end
diff --git a/script-beta/vm/getValue.lua b/script-beta/vm/getInfer.lua
index bc83dc10..e98ac1a0 100644
--- a/script-beta/vm/getValue.lua
+++ b/script-beta/vm/getInfer.lua
@@ -189,25 +189,25 @@ local function checkBinary(source)
if op.type == 'and' then
local isTrue = vm.checkTrue(source[1])
if isTrue == true then
- return vm.getValue(source[2])
+ return vm.getInfers(source[2])
elseif isTrue == false then
- return vm.getValue(source[1])
+ return vm.getInfers(source[1])
else
return merge(
- vm.getValue(source[1]),
- vm.getValue(source[2])
+ vm.getInfers(source[1]),
+ vm.getInfers(source[2])
)
end
elseif op.type == 'or' then
local isTrue = vm.checkTrue(source[1])
if isTrue == true then
- return vm.getValue(source[1])
+ return vm.getInfers(source[1])
elseif isTrue == false then
- return vm.getValue(source[2])
+ return vm.getInfers(source[2])
else
return merge(
- vm.getValue(source[1]),
- vm.getValue(source[2])
+ vm.getInfers(source[1]),
+ vm.getInfers(source[2])
)
end
elseif op.type == '==' then
@@ -711,7 +711,7 @@ local function inferBySetOfLocal(results, source)
if ref.type == 'setlocal' then
break
end
- merge(results, vm.getValue(ref))
+ merge(results, vm.getInfers(ref))
end
end
end
@@ -724,7 +724,7 @@ local function inferBySet(results, source)
inferBySetOfLocal(results, source)
elseif source.type == 'setlocal'
or source.type == 'getlocal' then
- merge(results, vm.getValue(source.node))
+ merge(results, vm.getInfers(source.node))
end
end
@@ -736,7 +736,7 @@ local function mergeFunctionReturns(results, source, index)
for i = 1, #returns do
local rtn = returns[i]
if rtn[index] then
- merge(results, vm.getValue(rtn[index]))
+ merge(results, vm.getInfers(rtn[index]))
end
end
end
@@ -749,7 +749,7 @@ local function inferByCallReturn(results, source)
return
end
local node = source.vararg.node
- local nodeValues = vm.getValue(node)
+ local nodeValues = vm.getInfers(node)
if not nodeValues then
return
end
@@ -786,7 +786,7 @@ local function inferByPCallReturn(results, source)
else
return
end
- local funcValues = vm.getValue(func)
+ local funcValues = vm.getInfers(func)
if not funcValues then
return
end
@@ -817,7 +817,7 @@ function vm.inferValue(source, infer)
results = {}
inferByLibraryArg(results, source)
- --inferByDef(results, source)
+ inferByDef(results, source)
inferBySet(results, source)
inferByCall(results, source)
inferByGetTable(results, source)
@@ -834,7 +834,7 @@ function vm.inferValue(source, infer)
end
function vm.checkTrue(source)
- local values = vm.getValue(source)
+ local values = vm.getInfers(source)
if not values then
return
end
@@ -869,7 +869,7 @@ end
--- 获取特定类型的字面量值
function vm.getLiteral(source, type)
- local values = vm.getValue(source)
+ local values = vm.getInfers(source)
if not values then
return nil
end
@@ -885,8 +885,8 @@ function vm.getLiteral(source, type)
end
function vm.isSameValue(a, b)
- local valuesA = vm.getValue(a)
- local valuesB = vm.getValue(b)
+ local valuesA = vm.getInfers(a)
+ local valuesB = vm.getInfers(b)
if not valuesA or not valuesB then
return false
end
@@ -921,13 +921,13 @@ end
--- 是否包含某种类型
function vm.hasType(source, type)
- local values = vm.getValue(source)
- if not values then
+ local infers = vm.getInfers(source)
+ if not infers then
return false
end
- for i = 1, #values do
- local value = values[i]
- if value.type == type then
+ for i = 1, #infers do
+ local infer = infers[i]
+ if infer.type == type then
return true
end
end
@@ -935,26 +935,15 @@ function vm.hasType(source, type)
end
function vm.getType(source)
- local values = vm.getValue(source)
- return guide.viewInfer(values)
+ local infers = vm.getInfers(source)
+ return guide.viewInfer(infers)
end
--- 获取对象的值
--- 会尝试穿透函数调用
-function vm.getValue(source)
+function vm.getInfers(source)
if not source then
return
end
- local cache = vm.getCache('getValue')[source]
- if cache ~= nil then
- return cache
- end
- local unlock = vm.lock('getValue', source)
- if not unlock then
- return
- end
- cache = guide.requestInfer(source, vm.interface) or false
- vm.getCache('getValue')[source] = cache
- unlock()
- return cache
+ return guide.requestInfer(source, vm.interface)
end
diff --git a/script-beta/vm/getLibrary.lua b/script-beta/vm/getLibrary.lua
index 34047805..69bc3f2b 100644
--- a/script-beta/vm/getLibrary.lua
+++ b/script-beta/vm/getLibrary.lua
@@ -49,7 +49,7 @@ end
local function getNodeAsObject(source)
local node = source.node
- local values = vm.getValue(node)
+ local values = vm.getInfers(node)
if not values then
return nil
end
diff --git a/script-beta/vm/guideInterface.lua b/script-beta/vm/guideInterface.lua
index bd7a7c2b..e4c27531 100644
--- a/script-beta/vm/guideInterface.lua
+++ b/script-beta/vm/guideInterface.lua
@@ -95,7 +95,7 @@ function vm.interface.cache(source, mode)
cache[mode] = {}
end
local sourceCache = cache[mode][source]
- if cache[mode][source] then
+ if sourceCache then
return sourceCache
end
sourceCache = {}
diff --git a/script-beta/vm/init.lua b/script-beta/vm/init.lua
index bc2b41c1..7621859f 100644
--- a/script-beta/vm/init.lua
+++ b/script-beta/vm/init.lua
@@ -1,7 +1,7 @@
local vm = require 'vm.vm'
require 'vm.getGlobals'
require 'vm.getLibrary'
-require 'vm.getValue'
+require 'vm.getInfer'
require 'vm.getClass'
require 'vm.getMeta'
require 'vm.eachField'
diff --git a/test-beta/type_inference/init.lua b/test-beta/type_inference/init.lua
index 7842316e..17a2ae5b 100644
--- a/test-beta/type_inference/init.lua
+++ b/test-beta/type_inference/init.lua
@@ -30,7 +30,7 @@ function TEST(wanted)
files.setText('', newScript)
local source = getSource(pos)
assert(source)
- local result = vm.getType(source) or 'any'
+ local result = vm.getType(source)
assert(wanted == result)
end
end