Skip to content

Commit

Permalink
修复区间未结束时的区间判断不正确的问题 (#248)
Browse files Browse the repository at this point in the history
- 当前是12号14点
- 12号下午
目前是给出下个月的12号下午,在 `beforeEndOfInterval` 有效的情况下,应该给出本月12号的。
  • Loading branch information
du00cs authored Oct 28, 2024
1 parent b7503c2 commit 65dceca
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ case class TimeData(timePred: TimePredicate,
|| hint != Hint.Recent && !options.timeOptions.alwaysInFuture)
val valueOpt =
try {
resolveTimeData(refTime, this, reverseTake)
resolveTimeData(refTime, this, reverseTake, options)
} catch {
case e: java.time.DateTimeException =>
logger.error(s"time resolve failed with DateTimeException [${e.getMessage}]")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ object TimePredicateHelpers {
notImmediate: Boolean,
cyclicPred: TimePredicate,
basePred: TimePredicate): TimePredicate = {
def f(t: TimeObject, ctx: TimeContext): Option[TimeObject] = {
val (past, future) = runPredicate(cyclicPred)(t, ctx)
def f(t: TimeObject, ctx: TimeContext, options: Options): Option[TimeObject] = {
val (past, future) = runPredicate(cyclicPred)(t, ctx, null)
val rest = if (n >= 0) {
future match {
case ahead #:: _ if notImmediate && timeBefore(ahead, t) => future.drop(n + 1)
Expand All @@ -151,7 +151,7 @@ object TimePredicateHelpers {
def timeCycle(grain: Grain): CycleSeriesPredicate = timeCycle(grain, grain)

def timeCycle(grain: Grain, roundGrain: Grain, step: Int = 1): CycleSeriesPredicate = {
CycleSeriesPredicate((t: TimeObject, _: TimeContext) => {
CycleSeriesPredicate((t: TimeObject, _: TimeContext, _: Options) => {
timeSequence(grain, step, if (roundGrain != NoGrain) timeRound(t, roundGrain) else t)
}, step, grain)
}
Expand All @@ -160,10 +160,10 @@ object TimePredicateHelpers {
* Takes `n` cycles of `f`
*/
def takeN(literalN: Int, notImmediate: Boolean, cycleSP: CycleSeriesPredicate): TimePredicate = {
def series(t: TimeObject, context: TimeContext) = {
def series(t: TimeObject, context: TimeContext, options: Options) = {
val baseTime = context.refTime
// 确定起点
val (past, future) = runPredicate(cycleSP)(baseTime, context)
val (past, future) = runPredicate(cycleSP)(baseTime, context, options)
val fut = future match {
case ahead #:: rest if notImmediate && timeIntersect(ahead)(baseTime).nonEmpty => rest
case _ => future
Expand Down Expand Up @@ -200,8 +200,8 @@ object TimePredicateHelpers {
* 0 is the first element in the future
*/
def takeNth(n: Int, notImmediate: Boolean, f: TimePredicate): TimePredicate = {
val series = (t: TimeObject, context: TimeContext) => {
val (past, future) = runPredicate(f)(context.refTime, context)
val series = (t: TimeObject, context: TimeContext, options: Options) => {
val (past, future) = runPredicate(f)(context.refTime, context, options)
val rest = if (n >= 0) {
future match {
case Stream.Empty => Stream.Empty
Expand Down Expand Up @@ -232,7 +232,7 @@ object TimePredicateHelpers {
}

def solarTermPredicate(term: String): SeriesPredicate = {
val series: SeriesPredicateF = (t: TimeObject, context: TimeContext) => {
val series: SeriesPredicateF = (t: TimeObject, context: TimeContext, options: Options) => {
if (!containSolarTerm(t.start.year, term)) (Stream.empty, Stream.empty)
else {
def f(step: Int)(to: TimeObject): TimeObject = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
package com.xiaomi.duckling.dimension

import com.github.heqiao2010.lunar.{LunarCalendar, LunarData}

import java.time.LocalTime

import com.xiaomi.duckling.Types.{ZoneCN, conf}
import com.xiaomi.duckling.Types.{conf, Options, ZoneCN}
import com.xiaomi.duckling.dimension.time.Types.{TimeContext, TimeObject, _}
import com.xiaomi.duckling.dimension.time.enums.AMPM._
import com.xiaomi.duckling.dimension.time.enums.Grain._
Expand All @@ -36,7 +35,7 @@ package object time {
/**
* Return a tuple of (past, future) elements
*/
type SeriesPredicateF = (TimeObject, TimeContext) => PastFutureTime
type SeriesPredicateF = (TimeObject, TimeContext, Options) => PastFutureTime

implicit class GrainWrapper(grain: Grain) {
def <(that: Grain): Boolean = grain.compareTo(that) < 0
Expand Down Expand Up @@ -82,11 +81,12 @@ package object time {

def resolveTimeData(refTime: TimeObject,
td: TimeData,
reverseTake: Boolean = false): Option[TimeObject] = {
reverseTake: Boolean = false,
options: Options): Option[TimeObject] = {

val tc = refTimeContext(refTime, reverseTake)

val (past, future) = runPredicate(td.timePred)(refTime, tc)
val (past, future) = runPredicate(td.timePred)(refTime, tc, options)

val reverse = if (reverseTake) {
future match {
Expand Down Expand Up @@ -118,7 +118,8 @@ package object time {
// 1. 过了一部分还需要再出的,12号 => 2/12,2月 => 2013/2
// 2. 问4点,需要给出 16:00
val g = if (td.timeGrain >= Grain.Day) td.timeGrain else Grain.NoGrain
timeBefore(ahead, refTime, g)
if (options.timeOptions.beforeEndOfInterval) endBefore(ahead, refTime, g)
else timeBefore(ahead, refTime, g)
case TimeIntervalsPredicate(_, _, _, beforeEndOfInterval) =>
val g = if (td.timeGrain >= Grain.Day) td.timeGrain else Grain.NoGrain
if (!beforeEndOfInterval) timeBefore(ahead, refTime, g)
Expand All @@ -134,7 +135,7 @@ package object time {
}

val EmptySeries: PastFutureTime = (Stream.empty, Stream.empty)
val EmptySeriesPredicate: SeriesPredicateF = (_: TimeObject, _: TimeContext) => EmptySeries
val EmptySeriesPredicate: SeriesPredicateF = (_: TimeObject, _: TimeContext, _: Options) => EmptySeries

def runPredicate(tp: TimePredicate): SeriesPredicateF = {
tp match {
Expand All @@ -153,9 +154,9 @@ package object time {
year.map(runYearPredicate)
).flatten

def series(t: TimeObject, tc: TimeContext): PastFutureTime = {
def series(t: TimeObject, tc: TimeContext, options: Options): PastFutureTime = {
val pred = toCompose.reduceOption(runCompose).getOrElse(EmptySeriesPredicate)
val (past, future) = pred(t, tc)
val (past, future) = pred(t, tc, options)
(past, future)
}

Expand All @@ -172,7 +173,7 @@ package object time {
}
}

def runEndOfGrainPredicate(t: TimeObject, context: TimeContext): PastFutureTime = {
def runEndOfGrainPredicate(t: TimeObject, context: TimeContext, options: Options): PastFutureTime = {
val (start, grain) = t.grain match {
case Grain.Month =>
(t.start.plusMonths(1).plusDays(-1), Day)
Expand All @@ -186,10 +187,10 @@ package object time {
def runReplacePartPredicate(
td1: TimeData,
td2: TimeData
)(t: TimeObject, context: TimeContext): PastFutureTime = {
)(t: TimeObject, context: TimeContext, options: Options): PastFutureTime = {
(for {
t1 <- resolveTimeData(t, td1)
t2 <- resolveTimeData(t, td2)
t1 <- resolveTimeData(t, td1, options = options)
t2 <- resolveTimeData(t, td2, options = options)
} yield {
val to =
if (td2.timePred.maxGrain.nonEmpty && td1.timeGrain > td2.timePred.maxGrain.get) {
Expand Down Expand Up @@ -255,14 +256,14 @@ package object time {

@scala.annotation.tailrec
def runSequencePredicate(list: List[TimeData])(t: TimeObject,
context: TimeContext): PastFutureTime = {
context: TimeContext, options: Options): PastFutureTime = {
list match {
case Nil => (Stream.empty, Stream(context.refTime))
case td :: xs =>
resolveTimeData(t, td) match {
resolveTimeData(t, td, options = options) match {
case Some(refTime) =>
val tc = refTimeContext(refTime)
runSequencePredicate(xs)(refTime, tc)
runSequencePredicate(xs)(refTime, tc, options)
case None => EmptySeries
}
}
Expand All @@ -272,13 +273,13 @@ package object time {
runCompose(runPredicate(pred1), runPredicate(pred2))
}

def runSecondPredicate(n: Int)(t: TimeObject, context: TimeContext): PastFutureTime = {
def runSecondPredicate(n: Int)(t: TimeObject, context: TimeContext, options: Options): PastFutureTime = {
val s = t.start.second
val anchor = timePlus(timeRound(t, Second), Second, n - s % 60)
timeSequence(Minute, 1, anchor)
}

def runMinutePredicate(n: Int)(t: TimeObject, context: TimeContext): PastFutureTime = {
def runMinutePredicate(n: Int)(t: TimeObject, context: TimeContext, options: Options): PastFutureTime = {
val rounded = timeRound(t, Minute)
val m = t.start.minute
val anchor = timePlus(rounded, Minute, (n - m) % 60)
Expand All @@ -287,7 +288,7 @@ package object time {

def runHourPredicate(
ampm: Option[AMPM]
)(hour: (Boolean, Int))(t: TimeObject, context: TimeContext): PastFutureTime = {
)(hour: (Boolean, Int))(t: TimeObject, context: TimeContext, options: Options): PastFutureTime = {
val (is12H, n) = hour
val step = if (is12H && n <= 12 && ampm.isEmpty) 12 else 24
val nAdjust = ampm match {
Expand All @@ -308,13 +309,13 @@ package object time {
)
}

def runDayOfTheWeekPredicate(n: Int)(t: TimeObject, context: TimeContext): PastFutureTime = {
def runDayOfTheWeekPredicate(n: Int)(t: TimeObject, context: TimeContext, options: Options): PastFutureTime = {
val daysUntilNextWeek = Math.floorMod(n - t.start.dayOfWeek, 7)
val anchor = timePlus(timeRound(t, Day), Day, daysUntilNextWeek)
timeSequence(Day, 7, anchor)
}

def runDayOfTheMonthPredicate(n: Int)(t: TimeObject, context: TimeContext): PastFutureTime = {
def runDayOfTheMonthPredicate(n: Int)(t: TimeObject, context: TimeContext, options: Options): PastFutureTime = {

def enoughDays(t: TimeObject): Boolean = {
n <= t.start.date.lengthOfMonth
Expand All @@ -334,7 +335,7 @@ package object time {
(past, future)
}

def runMonthPredicate(calendar: Option[Calendar])(n: Int)(t: TimeObject, context: TimeContext): PastFutureTime = {
def runMonthPredicate(calendar: Option[Calendar])(n: Int)(t: TimeObject, context: TimeContext, options: Options): PastFutureTime = {
val y = timeRound(t, Year, calendar)
val rounded =
calendar match {
Expand All @@ -347,7 +348,7 @@ package object time {
timeSequence(Year, 1, anchor)
}

def runYearPredicate(n: Int)(t: TimeObject, context: TimeContext): PastFutureTime = {
def runYearPredicate(n: Int)(t: TimeObject, context: TimeContext, options: Options): PastFutureTime = {
val year = n
val tyear = t.start.year
val y = timePlus(timeRound(t, Year), Year, year - tyear)
Expand All @@ -362,14 +363,14 @@ package object time {
* Performs best when pred1 is smaller grain than pred2
*/
def runCompose(pred1: SeriesPredicateF, pred2: SeriesPredicateF): SeriesPredicateF = {
val series = (nowTime: TimeObject, context: TimeContext) => {
val (past, future) = pred2(nowTime, context)
val series = (nowTime: TimeObject, context: TimeContext, options: Options) => {
val (past, future) = pred2(nowTime, context, options)

def startsBefore(t1: TimeObject)(t: TimeObject): Boolean = timeStartsBeforeTheEndOf(t)(t1)

def computeSeries(tokens: Stream[TimeObject]): Stream[TimeObject] = {
tokens.take(safeMax).flatMap { time1 =>
val (past, future) = pred1(time1, fixedTimeContext(time1))
val (past, future) = pred1(time1, fixedTimeContext(time1), options)
val before = future.takeWhile(startsBefore(time1))
before.flatMap(timeIntersect(time1))
}
Expand All @@ -387,8 +388,8 @@ package object time {
pred2: TimePredicate,
beforeEndOfInterval: Boolean): SeriesPredicateF = {
// Pick the first interval *after* the given time segment
def f(thisSegment: TimeObject, ctx: TimeContext): Option[TimeObject] = {
runPredicate(pred2)(thisSegment, ctx) match {
def f(thisSegment: TimeObject, ctx: TimeContext, options: Options): Option[TimeObject] = {
runPredicate(pred2)(thisSegment, ctx, options) match {
case (_, firstFuture #:: tail) =>
// 避免9点-9点,左右一样(空区间)
val end = if (firstFuture != thisSegment || tail.headOption.isEmpty) firstFuture else tail.head
Expand All @@ -397,8 +398,8 @@ package object time {
}
}

def b(thisSegment: TimeObject, ctx: TimeContext): Option[TimeObject] = {
runPredicate(pred1)(thisSegment, ctx) match {
def b(thisSegment: TimeObject, ctx: TimeContext, options: Options): Option[TimeObject] = {
runPredicate(pred1)(thisSegment, ctx, options) match {
case (past, future) =>
val choosed = future.take(safeMax).find(t => timeStartsBeforeTheEndOf(t)(thisSegment))
.orElse(past.take(safeMax).find(t => timeStartsBeforeTheEndOf(t)(thisSegment)))
Expand All @@ -425,15 +426,15 @@ package object time {
* @return Series generator for values that come from `f`
*/
def timeSeqMap(dontReverse: Boolean,
f: (TimeObject, TimeContext) => Option[TimeObject],
f: (TimeObject, TimeContext, Options) => Option[TimeObject],
g: TimePredicate): SeriesPredicateF = {
def seriesF(nowTime: TimeObject, context: TimeContext) = {
def seriesF(nowTime: TimeObject, context: TimeContext, options: Options) = {
// computes a single interval from `f` based on each interval in the series
def applyF(series: Stream[TimeObject]) = {
series.take(safeMaxInterval).flatMap(f(_, context))
series.take(safeMaxInterval).flatMap(f(_, context, options))
}

val (firstPast, firstFuture) = runPredicate(g)(nowTime, context)
val (firstPast, firstFuture) = runPredicate(g)(nowTime, context, options)
val (past1, future1) = (applyF(firstPast), applyF(firstFuture))

// Separate what's before and after now from the past's series
Expand Down Expand Up @@ -466,7 +467,7 @@ package object time {
case _ => false
}

def runTimeOpenIntervalPredicate(it: IntervalDirection)(t: TimeObject, context: TimeContext): PastFutureTime = {
def runTimeOpenIntervalPredicate(it: IntervalDirection)(t: TimeObject, context: TimeContext, options: Options): PastFutureTime = {
(Stream(t.copy(direction = Some(it))), Stream.empty)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@ import com.xiaomi.duckling.dimension.time.helper.TimeObjectHelpers.{timeIntersec
import com.xiaomi.duckling.dimension.time.Types._
import com.xiaomi.duckling.dimension.time.enums.Grain
import com.xiaomi.duckling.ranking.Testing
import com.xiaomi.duckling.Types.ZoneCN
import com.xiaomi.duckling.Types.{Options, ZoneCN}
import com.xiaomi.duckling.UnitSpec

class TypesTest extends UnitSpec {

describe("TypesTest") {

def round1(refTime: TimeObject, td: TimeData): Option[TimeObject] = {
def round1(refTime: TimeObject, td: TimeData, options: Options): Option[TimeObject] = {
val tc = TimeContext(
refTime = refTime,
maxTime = timePlus(refTime, Grain.Year, 2000),
minTime = timePlus(refTime, Grain.Year, -2000)
)
val (past, future) = runPredicate(td.timePred)(refTime, tc)
val (past, future) = runPredicate(td.timePred)(refTime, tc, options)

val valueOpt = future match {
case Stream.Empty => past.headOption
Expand All @@ -53,13 +53,14 @@ class TypesTest extends UnitSpec {

it("sequence apply demo") {
val refTime = new TimeObject(Testing.testContext.referenceTime, Grain.Second)
val options = Options()
val td1 = cycleNth(Day, 1)

val r1 = round1(refTime, td1).get
val r1 = round1(refTime, td1, options).get
r1.start.dayOfMonth shouldBe 13

val td2 = cycleNth(Day, 2)
val r2 = round1(r1, td2).get
val r2 = round1(r1, td2, options).get
r2.start.dayOfMonth shouldBe 15
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ object NaiveBayesDebug {
options.timeOptions.setResetTimeOfDay(false)
options.timeOptions.setRecentInFuture(true)
options.timeOptions.setAlwaysInFuture(true)
options.timeOptions.setBeforeEndOfInterval(false)
options.timeOptions.setBeforeEndOfInterval(true)
options.numeralOptions.setAllowZeroLeadingDigits(false)
options.numeralOptions.setCnSequenceAsNumber(false)

Expand Down

0 comments on commit 65dceca

Please sign in to comment.