UVA 1401: Remember the Word —— 动态规划 + 字典树

题目链接 ( UVA 这个网站好像要先登录才能看题)

题目大意:

给你一个长字符串 $str$ 和另外 $s$ 个短字符串,要将 $str$ 用给出的 $s$ 个字符串组合起来。问有多少种组合方式。

输入格式:

多组数据。

每组数据第一行一个字符串 $str$ 。下一行一个整数 $s$ 。接下来的 $s$ 行每一行一个字符串。

输出格式:

输出答案对 $20071027$ 取模的结果。格式与样例一致。

测试样例:

Input
abcd
4
a
b
cd
ab
Output
Case 1: 2

样例解释:

字符串 abcd 可以用 a, b, cd 组合,也可以用 ab, cd 组合。所以是两种。

分析:

动态规划。

用 $dp_i$ 表示 以 $str_i$ 开头 的字符串可以被分解的数量。我们从 $i$ 开始往后枚举,枚举到 $j$ 的时候,若 $str_i$ -> $str_j$ 正好是给出的 $s$ 个字符串中的一个完整的字符串,那么我们就可以更新 $dp_i$:

dp[i] += dp[j + 1]; 

那么如何判断 $str_i$ -> $strj$ 是否是一个完整的字符串呢。我们可以考虑用字典树。若 $str_j$ 在树上的节点被打了标记,就说明有一个字符串以这个字母结尾,也就是一个完整的字符串了。

最后要说的是,这道题的字典树好像比较适合用数组写,我用邻接表要超时…(因为邻接表在查找某一个节点有没有某一个儿子的时候应该必须要枚举)

代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <algorithm>
#include <vector>
using namespace std;

template<class _T> inline void read(_T& _x) {
  int _t, _flag = 0;
  while ((_t = getchar()) != '-' && (_t < '0' || _t > '9'));
  if (_t == '-')  _flag = 1, _t = getchar();
  _x = _t - '0';
  while ((_t = getchar()) >= '0' && _t <= '9')
    _x = _x * 10 + _t - '0';
  if (_flag)  _x = -_x;
}

typedef long long ll;
typedef pair<int, int> pii;

const int MAX_N = 3 * 1e5 + 100, MOD = 20071027;
char str[MAX_N], t[MAX_N];
int s, vis[MAX_N], dp[MAX_N];
vector<int> ans[MAX_N];

struct Trie {
  int ch[MAX_N][28], sz, val[MAX_N];
  void init() {
    sz = 1;
    val[0] = 0;
    memset(ch[0], 0, sizeof(ch[0]));
  }
  void insert(char* s) {
    int len = strlen(s), u = 0;
    for (int i = 0; i < len; ++i) {
      int c = s[i] - 'a';
      if (!ch[u][c]) {
        memset(ch[sz], 0, sizeof(ch[sz]));
        val[sz] = 0;
        ch[u][c] = sz++;
      }
      u = ch[u][c];
    }
    ++val[u];
  }
} trie;

void DP() {
  int len = strlen(str);
  dp[len] = 1;
  for (int i = len - 1; i >= 0; --i) {
    int now = 0, j = i;
    while (j <= len - 1) {
      if (trie.ch[now][str[j] - 'a']) {
        now = trie.ch[now][str[j] - 'a'];
        if (trie.val[now]) {
          dp[i] += dp[j + 1];
          dp[i] %= MOD;
        }
        ++j;
      }
      else break;
    }
  }
}

int main() {
  int T = 0;
  while (~scanf("%s", str)) {
    trie.init();
    memset(dp, 0, sizeof(dp));
    read(s);
    for (int i = 1; i <= s; ++i) {
      scanf("%s", t);
      trie.insert(t);
    }
    DP();
    printf("Case %d: %d\n", ++T, dp[0]);
  }
}