diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2019-01-25 14:47:22 +0800 |
---|---|---|
committer | 最萌小汐 <sumneko@hotmail.com> | 2019-01-25 14:47:22 +0800 |
commit | 03cf351259bb4df569d7566ea6f108e200433c39 (patch) | |
tree | 7ae8ec188f162a8ddafb40ee56f240e0c9d1868f /server | |
parent | 98e8f056144cce9eaca7d39e4c56d44cd87e8b2c (diff) | |
download | lua-language-server-03cf351259bb4df569d7566ea6f108e200433c39.zip |
支持多类型推测
Diffstat (limited to 'server')
-rw-r--r-- | server/src/core/value.lua | 43 | ||||
-rw-r--r-- | server/src/core/vm.lua | 42 |
2 files changed, 58 insertions, 27 deletions
diff --git a/server/src/core/value.lua b/server/src/core/value.lua index 632f6cd4..5aa374ad 100644 --- a/server/src/core/value.lua +++ b/server/src/core/value.lua @@ -3,9 +3,8 @@ local DefaultSource = { start = 0, finish = 0 } local mt = {} mt.__index = mt mt.type = 'value' -mt._type = 'any' -function mt:setValue(source, value) +function mt:setValue(value) self._value = value end @@ -13,17 +12,34 @@ function mt:getValue() return self._value end -function mt:inference(tp) +function mt:inference(tp, rate) if tp == '...' then error('Value type cant be ...') end - if self._type == 'any' and tp ~= 'nil' then - self._type = tp + if not tp or tp == 'any' then + return + end + if not self._type then + self._type = {} + end + if not self._type[tp] or rate > self._type[tp] then + self._type[tp] = rate end end function mt:getType() - return self._type + if not self._type then + return 'nil' + end + local mRate = 0.0 + local mType + for tp, rate in pairs(self._type) do + if rate > mRate then + mRate = rate + mType = tp + end + end + return mType or 'any' end function mt:createField(name, source) @@ -42,7 +58,7 @@ function mt:createField(name, source) end self._child[uri][name] = field - self:inference('table') + self:inference('table', 0.5) return field end @@ -135,15 +151,18 @@ return function (tp, source, value) error('Value type cant be ...') end -- TODO lib里的多类型 - if type(tp) == 'table' then - tp = tp[1] - end local self = setmetatable({ source = source or DefaultSource, - _type = tp, }, mt) if value ~= nil then - self:setValue(source, value) + self:setValue(value) + end + if type(tp) == 'table' then + for i = 1, #tp do + self:inference(tp[i], 0.9) + end + else + self:inference(tp, 1.0) end return self end diff --git a/server/src/core/vm.lua b/server/src/core/vm.lua index 219c5be1..68ef0efb 100644 --- a/server/src/core/vm.lua +++ b/server/src/core/vm.lua @@ -631,20 +631,23 @@ function mt:callDoFile(func, values) end function mt:call(func, values, source) - func:inference('function') + func:inference('function', 0.9) local lib = func.lib if lib then if lib.args then for i, arg in ipairs(lib.args) do - -- TODO 反向推测调用参数的类型 + local value = values[i] + if value and arg.type ~= '...' then + value:inference(arg.type, 1.0) + end end end if lib.returns then for i, rtn in ipairs(lib.returns) do if rtn.type == '...' then - self:getFunctionReturns(func, i):inference('any') + self:getFunctionReturns(func, i):inference('any', 0.0) else - self:getFunctionReturns(func, i):inference(rtn.type or 'any') + self:getFunctionReturns(func, i):inference(rtn.type or 'any', 1.0) end end end @@ -1029,8 +1032,10 @@ function mt:getBinary(exp) or op == '<' or op == '>' then - v1:inference('number') - v2:inference('number') + v1:inference('number', 0.9) + v2:inference('number', 0.9) + v1:inference('string', 0.1) + v2:inference('string', 0.1) return self:createValue('boolean') elseif op == '~=' or op == '==' @@ -1042,8 +1047,12 @@ function mt:getBinary(exp) or op == '<<' or op == '>>' then - v1:inference('integer') - v2:inference('integer') + v1:inference('integer', 0.9) + v2:inference('integer', 0.9) + v1:inference('number', 0.9) + v2:inference('number', 0.9) + v1:inference('string', 0.1) + v2:inference('string', 0.1) if math.type(v1:getValue()) == 'integer' and math.type(v2:getValue()) == 'integer' then if op == '|' then return self:createValue('integer', v1:getValue() | v2:getValue()) @@ -1059,8 +1068,10 @@ function mt:getBinary(exp) end return self:createValue('integer') elseif op == '..' then - v1:inference('string') - v2:inference('string') + v1:inference('string', 0.9) + v2:inference('string', 0.9) + v1:inference('number', 0.1) + v2:inference('number', 0.1) if type(v1:getValue()) == 'string' and type(v2:getValue()) == 'string' then return self:createValue('string', nil, v1:getValue() .. v2:getValue()) end @@ -1073,8 +1084,8 @@ function mt:getBinary(exp) or op == '%' or op == '//' then - v1:inference('number') - v2:inference('number') + v1:inference('number', 0.9) + v2:inference('number', 0.9) if type(v1:getValue()) == 'number' and type(v2:getValue()) == 'number' then if op == '+' then return self:createValue('number', nil, v1:getValue() + v2:getValue()) @@ -1111,19 +1122,20 @@ function mt:getUnary(exp) if op == 'not' then return self:createValue('boolean') elseif op == '#' then - v1:inference('table') + v1:inference('table', 0.9) + v1:inference('string', 0.9) if type(v1:getValue()) == 'string' then return self:createValue('integer', nil, #v1:getValue()) end return self:createValue('integer') elseif op == '-' then - v1:inference('number') + v1:inference('number', 0.9) if type(v1:getValue()) == 'number' then return self:createValue('number', nil, -v1:getValue()) end return self:createValue('number') elseif op == '~' then - v1:inference('integer') + v1:inference('integer', 0.9) if math.type(v1:getValue()) == 'integer' then return self:createValue('integer', nil, ~v1:getValue()) end |