HachiML commited on
Commit
0142c7d
1 Parent(s): c2792a3

Upload modeling_moment.py

Browse files
Files changed (1) hide show
  1. modeling_moment.py +15 -0
modeling_moment.py CHANGED
@@ -503,6 +503,21 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
503
  input_mask = torch.ones_like(time_series_values[:, 0, :])
504
 
505
  return self.embed(x_enc=time_series_values, input_mask=input_mask, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506
 
507
 
508
  # refers: https://github.com/moment-timeseries-foundation-model/moment/blob/088b253a1138ac7e48a7efc9bf902336c9eec8d9/momentfm/models/moment.py#L601
 
503
  input_mask = torch.ones_like(time_series_values[:, 0, :])
504
 
505
  return self.embed(x_enc=time_series_values, input_mask=input_mask, **kwargs)
506
+
507
+ def calculate_n_patches(self, seq_len: int) -> int:
508
+ """
509
+ 時系列の長さ(seq_len)を与えて、モデルのself.patch_lenとself.strideを使ってn_patchesを計算して返します。
510
+ strideがNoneの場合はpatch_lenを使用します。
511
+
512
+ Args:
513
+ seq_len (int): 時系列の長さ
514
+
515
+ Returns:
516
+ int: 計算されたn_patchesの数
517
+ """
518
+ stride = self.stride if self.stride is not None else self.patch_len
519
+ n_patches = (seq_len - self.patch_len) // stride + 1
520
+ return n_patches
521
 
522
 
523
  # refers: https://github.com/moment-timeseries-foundation-model/moment/blob/088b253a1138ac7e48a7efc9bf902336c9eec8d9/momentfm/models/moment.py#L601