「题解」 P5608 [Ynoi2013] 文化课

AtomAlpaca

Table of contents

题意

link

简述题意:给定由数字和加、乘两种符号组成的表达式,支持区间覆盖数,区间覆盖符号,区间查询表达式值。

题解

线段树好题。学了线段树就能写。比较考验码力。

首先考虑节点上要维护哪些信息。

首先不考虑修改,只考虑快速合并两个区间,我们要记录每个节点表示区间的表达式值 vl ;考虑两个区间交界处的符号是乘号时合并不能简单地把两个答案加起来,我们还要维护区间最右侧的符号 op,左侧、右侧最长的一段连乘段的乘积 lmrm

考虑修改符号。首先要设置两个懒标记 atmt 分别代表覆盖为加、乘。考虑覆盖加后,整个节点的值变为了所有元素的和,我们还要维护区间和 sm;对于乘法,我们同理维护一个区间积 sm。另外,区间覆盖加还会把左、右端最长连乘段变得只剩下一个数,因此还要维护区间最左、右侧值 lvrv

修改数字稍微麻烦一些。我们考虑以加号为分界,把每个区间拆成若干连乘段,长度为 lenlen,区间改数字就是将区间值改为 xx 的若干个 lenlen 次幂的和。因此我们考虑开一个 vector 存下每个区间内的所有区间连乘段的长度。为了方便区间覆盖加,我们还可以维护一个区间长度 ln,不过这并不是必要的。

这时候发现,合并的两个区间的分界处是乘号时,对合并成的新区间内的连乘段也有影响,于是考虑再维护一个最左、右端连乘段长度 lcrc

好!现在弄明白了维护哪些信息,具体的维护方法也呼之欲出了!

每次将一个节点覆盖加号,将 op 改为加,vl 改为 smlcrc 变为 1lmrm 分别改为 lvrv,然后把区间长度 lenlen1 扔进 vector 里即可。

区间覆盖乘号、数字和区间覆盖加号的原理基本相同,这里不加赘述。

然后是区间合并。当分界处符号(也就是左区间最右侧符号 op)为加的时候,显然新区间的值就是两个区间的值加起来;否则应当分别去掉两个区间的右、左极长连乘段,在把连乘段的乘积加上;其次将两个区间的 vector 合并起来,极长左乘右乘段跨区间时类似区间值处理即可。

但是这样写,你会发现会 T 掉一些点,我们考虑优化。

首先我们发现查询答案的时候很多信息其实是不需要维护的,比如区间内的 vector、区间和、区间积,最左、右值等等。于是我们可以另写一个区间合并来省下维护这些信息的时间。

然后我们发现区间内长度不同的数量稍微有些多,区间覆盖数字时挨个做快速幂还是不够快,于是我们考虑把这些段排序,然后从小到大扫,一路乘过去。要让区间有序,只需要在合并两个区间的时候把两个 vector 并归一下即可。

而且 vector 的速度也太慢了,但是我们区间内开桶又开不下,于是可以考虑根号分治,区间维护一个大小为 len\sqrt{len} 长度的桶,把 vector 里小于 len\sqrt{len} 的数扔到这个桶里,剩下的数不会超过 len\sqrt{len} 个,可以保证效率。为了省空间这个数组可以每个节点 new 一个出来,代替等长的数组。

代码

代码实现十分繁琐。要处处记得取模。

值域比较小的变量就不要开 long long,会 MLE。

#include <bits/stdc++.h>

using std::cin;
using std::cout;
using std::vector;

typedef long long ll;
typedef vector<ll> V;
typedef vector<ll>::iterator iter;

const int MOD = 1e9 + 7;
const ll MAX = 1e5 + 5;

ll n, m, l, r, op, x;
ll num[MAX], opr[MAX];

ll qp(ll a, ll x)
{
  ll res = 1;
  while (x)
  {
    if (x & 1) { res = res * a % MOD; }
    a = a * a % MOD; x >>= 1;
  }
  return res % MOD;
}

struct T
{
  bool at, mt, op;
  int lv, rv;
  int * c; V v;
  ll ln, sq;
  ll sm, ml, lm, rm, lc, rc, vl, ag;
  void clear()
  {
    for (int i = 1; i <= sq; ++i) { c[i] = 0; } v.clear();
  }

  void asgadd()
  {
    vl = sm % MOD; op = 0; lc = rc = 1; lm = lv, rm = rv;
    clear(); c[1] += ln;
    at = true; mt = false;
  }

  void asgmul()
  {
    vl = ml % MOD; op = 1; lc = rc = ln; lm = rm = ml % MOD;
    clear(); if (ln <= sq) { ++c[ln]; } else { v.push_back(ln); }
    mt = true; at = false;
  }

  void asg(ll v)
  {
    v %= MOD; lv = rv = v; vl = 0;
    lm = qp(v, lc); rm = qp(v, rc); ml = qp(v, ln); sm = v * ln % MOD;
    ll now = v;
    for (int i = 1; i <= sq; ++i, now = now * v % MOD)
    {
      if (c[i])
      {   
        vl = (vl + 1ll * now * c[i] % MOD) % MOD;
      }
    }
    for (int i : this -> v)
    {
      vl = (vl + qp(v, i)) % MOD; 
    }
    ag = v;
  }

  void merge(T & l, T & r)
  {
    sm = (l.sm + r.sm) % MOD;
    ml = 1ll * l.ml * r.ml % MOD;
    lv = l.lv, rv = r.rv;
    op = r.op;
    lm = l.lm % MOD, rm = r.rm % MOD, lc = l.lc, rc = r.rc;
    if (l.op and l.lc == l.ln) { lc = l.ln + r.lc; lm = 1ll * l.ml * r.lm % MOD; }
    if (l.op and r.rc == r.ln) { rc = r.ln + l.rc; rm = 1ll * r.ml * l.rm % MOD; }
    vl = (l.vl + r.vl) % MOD;
    if (l.op)
    {
      vl = ((((l.vl + r.vl) % MOD + 1ll * l.rm * r.lm % MOD) % MOD - (l.rm + r.lm) % MOD) % MOD + MOD) % MOD;
    }
    clear();
    for (int i = 1; i <= l.sq; ++i) { c[i] += l.c[i]; }
    for (int i = 1; i <= r.sq; ++i) { c[i] += r.c[i]; }
    int i1 = 0, i2 = 0, e1 = l.v.size(), e2 = r.v.size();
    while (i1 < e1 and i2 < e2)
    {
      if (l.v[i1] < r.v[i2])
      {
        if (l.v[i1] <= sq) { c[l.v[i1]]++; }
        else { v.push_back(l.v[i1]); }
        ++i1;
      }
      else
      {
        if (r.v[i2] <= sq) { c[r.v[i2]]++; }
        else { v.push_back(r.v[i2]); }
        ++i2;
      }
    }
    while (i1 < e1)
    {
      if (l.v[i1] <= sq) { c[l.v[i1]]++; }
      else { v.push_back(l.v[i1]); }
      ++i1;
    }
    while (i2 < e2)
    {
      if (r.v[i2] <= sq) { c[r.v[i2]]++; }
      else { v.push_back(r.v[i2]); }
      ++i2;
    }   
    if (l.op)
    {
      if (l.rc <= sq) { --c[l.rc]; }
      else { iter i = lower_bound(v.begin(), v.end(), l.rc); v.erase(i); }
      if (r.lc <= sq) { --c[r.lc]; }
      else { iter i = lower_bound(v.begin(), v.end(), r.lc); v.erase(i); }
      if (l.rc + r.lc <= sq) { ++c[l.rc + r.lc]; }
      else { iter i = lower_bound(v.begin(), v.end(), l.rc + r.lc); v.insert(i, l.rc + r.lc); }
    }
  }
  
  void pushdown(T & l, T & r)
  {
    if (ag) { l.asg(ag); r.asg(ag); ag = 0; } 
    if (at) { l.asgadd(); r.asgadd(); at = false; }
    if (mt) { l.asgmul(); r.asgmul(); mt = false; }
  }
} a[MAX << 2 | 1];

struct A
{
  ll lm, rm, lc, rc, vl, ln, op;
};

A merge(A l, A r)
{
  A x;
  x.op = r.op; x.ln = l.ln + r.ln;
  x.lc = l.lc; x.rc = r.rc;
  x.lm = l.lm; x.rm = r.rm;
  if (l.op and l.lc == l.ln) { x.lc = l.ln + r.lc; x.lm = 1ll * l.lm * r.lm % MOD; }
  if (l.op and r.rc == r.ln) { x.rc = r.ln + l.rc; x.rm = 1ll * r.rm * l.rm % MOD; }
  x.vl = (l.vl + r.vl) % MOD;
  if (l.op)
  {
    x.vl = (((x.vl + 1ll * l.rm * r.lm) % MOD - (l.rm + r.lm) % MOD) % MOD + MOD) % MOD;
  }
  return x;
}

A query(int l, int r, int s, int t, int x)
{
  if (l >= s and r <= t) { return {a[x].lm, a[x].rm, a[x].lc, a[x].rc, a[x].vl % MOD, a[x].ln, a[x].op}; }
  a[x].pushdown(a[x << 1], a[x << 1 | 1]);
  int k = l + ((r - l) >> 1);
  if (k >= t) { return query(l, k, s, t, x << 1); }
  else if (k <  s) { return query(k + 1, r, s, t, x << 1 | 1); }
  else { return merge(query(l, k, s, t, x << 1), query(k + 1, r, s, t, x << 1 | 1)); }
}

void build(int l, int r, int x)
{
  a[x].ln = r - l + 1; a[x].sq = sqrt(a[x].ln); a[x].c = new int [a[x].sq + 3]();
  if (l == r)
  {
    a[x].vl = a[x].lv = a[x].rv = a[x].lm = a[x].rm = a[x].sm = a[x].ml = num[l];
    a[x].op = opr[l]; a[x].lc = a[x].rc = 1; a[x].c[1] = 1;
    return ;
  }
  int k = l + ((r - l) >> 1);
  build(l, k, x << 1); build(k + 1, r, x << 1 | 1);
  a[x].merge(a[x << 1], a[x << 1 | 1]);
}

void asg(int l, int r, int s, int t, ll v, int x)
{
  if (l >= s and r <= t) { a[x].asg(v); return ; }
  a[x].pushdown(a[x << 1], a[x << 1 | 1]);
  int k = l + ((r - l) >> 1);
  if (k >= s) { asg(l, k, s, t, v, x << 1); }
  if (k <  t) { asg(k + 1, r, s, t, v, x << 1 | 1); }
  a[x].merge(a[x << 1], a[x << 1 | 1]);
}

void asgadd(int l, int r, int s, int t, int x)
{
  if (l >= s and r <= t) { a[x].asgadd(); return ; }
  a[x].pushdown(a[x << 1], a[x << 1 | 1]);
  int k = l + ((r - l) >> 1);
  if (k >= s) { asgadd(l, k, s, t, x << 1); }
  if (k <  t) { asgadd(k + 1, r, s, t, x << 1 | 1); }
  a[x].merge(a[x << 1], a[x << 1 | 1]);
}

void asgmul(int l, int r, int s, int t, int x)
{
  if (l >= s and r <= t) { a[x].asgmul(); return ; }
  a[x].pushdown(a[x << 1], a[x << 1 | 1]);
  int k = l + ((r - l) >> 1);
  if (k >= s) { asgmul(l, k, s, t, x << 1); }
  if (k <  t) { asgmul(k + 1, r, s, t, x << 1 | 1); }
  a[x].merge(a[x << 1], a[x << 1 | 1]);
}


int main()
{
  cin.tie(NULL); cout.tie(NULL); std::ios::sync_with_stdio(false);
  cin >> n >> m;
  for (int i = 1; i <= n; ++i) { cin >> num[i]; num[i] %= MOD; }
  for (int i = 1; i <  n; ++i) { cin >> opr[i]; }
  build(1, n, 1);
  while (m--)
  {
    cin >> op >> l >> r;
    if (op != 3) { cin >> x; }
    if (op == 1) { asg(1, n, l, r, x % MOD, 1); }
    else if (op == 2)
    {
      if (x == 0) { asgadd(1, n, l, r, 1); }
      else { asgmul(1, n, l, r, 1); }
    }
    else { cout << query(1, n, l, r, 1).vl % MOD << '\n'; }
  }
}