diff options
-rw-r--r-- | script/parser/guide.lua | 27 | ||||
-rw-r--r-- | script/parser/luadoc.lua | 18 | ||||
-rw-r--r-- | script/vm/compiler.lua | 1 | ||||
-rw-r--r-- | script/vm/global.lua | 1 | ||||
-rw-r--r-- | script/vm/operator.lua | 55 | ||||
-rw-r--r-- | test/type_inference/init.lua | 27 |
6 files changed, 94 insertions, 35 deletions
diff --git a/script/parser/guide.lua b/script/parser/guide.lua index f782c43a..0768cbb4 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -158,6 +158,7 @@ local childMap = { ['doc.as'] = {'as'}, ['doc.cast'] = {'loc', '#casts'}, ['doc.cast.block'] = {'extends'}, + ['doc.operator'] = {'op', 'exp', 'extends'} } ---@type table<string, fun(obj: parser.object, list: parser.object[])> @@ -250,32 +251,6 @@ m.actionMap = { ['funcargs'] = {'#'}, } -local inf = 1 / 0 -local nan = 0 / 0 - -local function isInteger(n) - if math.type then - return math.type(n) == 'integer' - else - return type(n) == 'number' and n % 1 == 0 - end -end - -local function formatNumber(n) - if n == inf - or n == -inf - or n == nan - or n ~= n then -- IEEE 标准中,NAN 不等于自己。但是某些实现中没有遵守这个规则 - return ('%q'):format(n) - end - if isInteger(n) then - return tostring(n) - end - local str = ('%.10f'):format(n) - str = str:gsub('%.?0*$', '') - return str -end - --- 是否是字面量 ---@param obj table ---@return boolean diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index a94f89cd..2237b232 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -819,8 +819,9 @@ local docSwitch = util.switch() : case 'class' : call(function () local result = { - type = 'doc.class', - fields = {}, + type = 'doc.class', + fields = {}, + operators = {}, } result.class = parseName('doc.class.name', result) if not result.class then @@ -1386,8 +1387,8 @@ local docSwitch = util.switch() local ret = parseType(result) if ret then - result.ret = ret - result.finish = ret.finish + result.extends = ret + result.finish = ret.finish end return result @@ -1497,8 +1498,10 @@ local function isContinuedDoc(lastDoc, nextDoc) return false end if lastDoc.type == 'doc.class' - or lastDoc.type == 'doc.field' then + or lastDoc.type == 'doc.field' + or lastDoc.type == 'doc.operator' then if nextDoc.type ~= 'doc.field' + and nextDoc.type ~= 'doc.operator' and nextDoc.type ~= 'doc.comment' and nextDoc.type ~= 'doc.overload' then return false @@ -1650,6 +1653,11 @@ local function bindClassAndFields(binded) class.fields[#class.fields+1] = doc doc.class = class end + elseif doc.type == 'doc.operator' then + if class then + class.operators[#class.operators+1] = doc + doc.class = class + end end end end diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index beb6e4e5..a13f1e3a 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -12,6 +12,7 @@ local vm = require 'vm.vm' ---@field _globalBase table ---@field cindex integer ---@field func parser.object +---@field operators? parser.object[] -- 该函数有副作用,会给source绑定node! ---@param source parser.object diff --git a/script/vm/global.lua b/script/vm/global.lua index b94e2768..d1bd798f 100644 --- a/script/vm/global.lua +++ b/script/vm/global.lua @@ -42,6 +42,7 @@ function mt:addGet(uri, source) self.getsCache = nil end +---@param suri uri ---@return parser.object[] function mt:getSets(suri) if not self.setsCache then diff --git a/script/vm/operator.lua b/script/vm/operator.lua index 0ed3ff1d..67d039e4 100644 --- a/script/vm/operator.lua +++ b/script/vm/operator.lua @@ -25,6 +25,47 @@ vm.BINARY_OP = { 'concat', } +---@param operators parser.object[] +---@param op string +---@param value? parser.object +---@param result? vm.node +---@return vm.node? +local function checkOperators(operators, op, value, result) + for _, operator in ipairs(operators) do + if operator.op[1] ~= op + or not operator.extends then + goto CONTINUE + end + if not result then + result = vm.createNode() + end + result:merge(vm.compileNode(operator.extends)) + ::CONTINUE:: + end + return result +end + +---@param op string +---@param exp parser.object +---@param value? parser.object +---@return vm.node? +function vm.runOperator(op, exp, value) + local uri = guide.getUri(exp) + local node = vm.compileNode(exp) + local result + for c in node:eachObject() do + if c.type == 'global' and c.cate == 'type' then + ---@cast c vm.global + for _, set in ipairs(c:getSets(uri)) do + if set.operators and #set.operators > 0 then + result = checkOperators(set.operators, op, value, result) + end + end + end + end + return result +end + vm.unarySwich = util.switch() : case 'not' : call(function (source) @@ -43,17 +84,22 @@ vm.unarySwich = util.switch() end) : case '#' : call(function (source) - vm.setNode(source, vm.declareGlobal('type', 'integer')) + local node = vm.runOperator('len', source[1]) + vm.setNode(source, node or vm.declareGlobal('type', 'integer')) end) : case '-' : call(function (source) local v = vm.getNumber(source[1]) if v == nil then + local uri = guide.getUri(source) local infer = vm.getInfer(source[1]) - if infer:hasType(guide.getUri(source), 'integer') then + if infer:hasType(uri, 'integer') then vm.setNode(source, vm.declareGlobal('type', 'integer')) - else + elseif infer:hasType(uri, 'number') then vm.setNode(source, vm.declareGlobal('type', 'number')) + else + local node = vm.runOperator('unm', source[1]) + vm.setNode(source, node or vm.declareGlobal('type', 'number')) end else vm.setNode(source, { @@ -69,7 +115,8 @@ vm.unarySwich = util.switch() : call(function (source) local v = vm.getInteger(source[1]) if v == nil then - vm.setNode(source, vm.declareGlobal('type', 'integer')) + local node = vm.runOperator('bnot', source[1]) + vm.setNode(source, node or vm.declareGlobal('type', 'integer')) else vm.setNode(source, { type = 'integer', diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index bae14650..f31a8777 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -3391,3 +3391,30 @@ TEST 'unknown' [[ mt = {} mt.<?x?> = nil ]] + +TEST 'A' [[ +---@class A +---@operator unm: A + +---@type A +local a +local <?b?> = -a +]] + +TEST 'A' [[ +---@class A +---@operator bnot: A + +---@type A +local a +local <?b?> = ~a +]] + +TEST 'A' [[ +---@class A +---@operator len: A + +---@type A +local a +local <?b?> = #a +]] |